Unverified Commit 452822a4 authored by q.yao, committed by GitHub

Add flashattention2 (#196)



* first

* fix causal mask

* disable flash attention2 on sm70

* fix 2

* update readme

* clang-format

* disable ft2 on windows

* fix lint

* fix build

* fix build

* fix long kv seq

* fix lint

* sync copy output

---------
Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
Co-authored-by: irexyc <irexyc@gmail.com>
parent d4d609bd
@@ -47,7 +47,7 @@ include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
-GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088
+GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60
)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -312,6 +312,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BaseSamplingLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:llama_fmha>
$<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
......
@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/08\] TurboMind supports flash-attention2.
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
......
@@ -20,6 +20,7 @@ ______________________________________________________________________
## Updates 🎉
- \[2023/08\] TurboMind supports flash-attention2
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. See [here](./docs/zh_cn/w4a16.md) for deployment instructions
......
@@ -41,5 +41,10 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
logger
llama_fmha)
if (NOT MSVC)
add_subdirectory(flash_attention2)
target_link_libraries(Llama PUBLIC flash_attention2)
endif()
add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
@@ -264,6 +264,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
{
//////////////////////////////////////////////
// flash attention
// flash attention 2 only supports half inputs
using AttentionOp = FlashAttentionOp<T>;
using Layout = typename AttentionOp::AttentionLayout;
Layout layout_q{
......
cmake_minimum_required(VERSION 3.8)
project(flash_attention2)
add_library(${PROJECT_NAME} STATIC
flash_api.cpp
flash_fwd_hdim32_fp16_sm80.cu
flash_fwd_hdim64_fp16_sm80.cu
flash_fwd_hdim128_fp16_sm80.cu
flash_fwd_hdim256_fp16_sm80.cu
)
target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass)
set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Flash Attention 2
This is a flash-attention2 implementation modified from https://github.com/Dao-AILab/flash-attention
- dropout removed
- backward pass removed
- based on cutlass 3.1.0
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen = true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
template<typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template<typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
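BlockInfo above resolves per-sequence offsets from the optional cu_seqlens arrays, which flash.h below describes as length b+1 arrays holding the starting offset of each sequence. As a hedged host-side sketch (not part of this commit; the helper name is illustrative), such an array can be built as an exclusive prefix sum:
#include <vector>
// cu_seqlens[i] is the starting row of sequence i in the packed varlen layout,
// so the actual length of sequence i is cu_seqlens[i + 1] - cu_seqlens[i].
inline std::vector<int> make_cu_seqlens(const std::vector<int>& seqlens)
{
    std::vector<int> cu(seqlens.size() + 1, 0);
    for (size_t i = 0; i < seqlens.size(); ++i) {
        cu[i + 1] = cu[i] + seqlens[i];
    }
    return cu;  // copy to device memory and point params.cu_seqlens_q / cu_seqlens_k at it
}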
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;
// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precomputed h / h_k
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params: public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void* __restrict__ p_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;
void* __restrict__ blockmask;
bool is_bf16;
bool is_causal;
// whether the q / o layouts are indexed by cumulative sequence lengths (varlen layout)
bool q_enable_seqlen;
bool o_enable_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention
#include "flash.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "static_switch.h"
#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include <math.h>
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
FP16_SWITCH(true,
[&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); }); });
}
namespace turbomind {
static constexpr int FMHA_VERSION = 2;
template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION> {
public:
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;
public:
FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOpImpl();
int get_workspace_size() const;
void operator()(Params& params, cudaStream_t st) const;
private:
class impl;
std::unique_ptr<impl> pimpl;
};
template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {
private:
using scalar_t =
typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
using Params = typename FlashAttentionOpImpl<T, FMHA_VERSION>::Params;
int batch_size_;
int head_num_;
int key_len_;
int seq_len_;
int size_per_head_;
public:
impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
batch_size_(batch_size),
head_num_(head_num),
key_len_(key_len),
seq_len_(seq_len),
size_per_head_(size_per_head)
{
}
~impl() {}
int get_workspace_size() const
{
return 0;
}
void operator()(Params& params, cudaStream_t st) const
{
const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head_ * 1.f));
Flash_fwd_params fwd_params;
memset(&fwd_params, 0, sizeof(fwd_params));
fwd_params.q_ptr = reinterpret_cast<void*>(params.query);
fwd_params.k_ptr = reinterpret_cast<void*>(params.key);
fwd_params.v_ptr = reinterpret_cast<void*>(params.val);
fwd_params.k_batched_ptr = reinterpret_cast<void**>(params.layout_k.batch_seqs);
fwd_params.v_batched_ptr = reinterpret_cast<void**>(params.layout_v.batch_seqs);
fwd_params.k_batched_offset = params.layout_k.batch_seqs_offset;
fwd_params.v_batched_offset = params.layout_v.batch_seqs_offset;
fwd_params.q_batch_stride = params.layout_q.stride_batch;
fwd_params.k_batch_stride = params.layout_k.stride_batch;
fwd_params.v_batch_stride = params.layout_v.stride_batch;
fwd_params.q_row_stride = params.layout_q.stride_seq;
fwd_params.k_row_stride = params.layout_k.stride_seq;
fwd_params.v_row_stride = params.layout_v.stride_seq;
fwd_params.q_head_stride = params.layout_q.stride_head;
fwd_params.v_head_stride = params.layout_v.stride_head;
fwd_params.k_head_stride = params.layout_k.stride_head;
fwd_params.h = head_num_;
fwd_params.h_k = head_num_ / params.group_size;
fwd_params.h_h_k_ratio = params.group_size;
fwd_params.o_ptr = reinterpret_cast<void*>(params.attn_out);
fwd_params.o_batch_stride = params.layout_o.stride_batch;
fwd_params.o_row_stride = params.layout_o.stride_seq;
fwd_params.o_head_stride = params.layout_o.stride_head;
fwd_params.p_ptr = nullptr;
fwd_params.b = batch_size_;
fwd_params.seqlen_q = seq_len_;
fwd_params.seqlen_k = key_len_;
fwd_params.d = size_per_head_;
fwd_params.seqlen_q_rounded = 0;
fwd_params.seqlen_k_rounded = 0;
fwd_params.scale_softmax = qk_scale;
fwd_params.scale_softmax_log2 = qk_scale * M_LOG2E;
fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;
fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
fwd_params.is_bf16 = false;
fwd_params.is_causal = true;
fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
fwd_params.o_enable_seqlen = params.layout_o.use_seqlens;
run_mha_fwd(fwd_params, st);
}
};
template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::FlashAttentionOpImpl(
int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
pimpl{std::make_unique<FlashAttentionOpImpl<T, FMHA_VERSION>::impl>(
batch_size, head_num, key_len, seq_len, size_per_head)}
{
}
template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::~FlashAttentionOpImpl()
{
}
template<typename T>
int FlashAttentionOpImpl<T, FMHA_VERSION>::get_workspace_size() const
{
return pimpl->get_workspace_size();
}
template<typename T>
void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStream_t st) const
{
pimpl->operator()(params, st);
}
template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;
} // namespace turbomind
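For orientation, here is a hedged sketch of how a caller might drive FlashAttentionOpImpl<T, 2>. The Params / layout field names are assumptions inferred from impl::operator() above (the real caller is LlamaContextAttentionLayer), not a verified public API:
// Hedged usage sketch only; field names follow the assignments in impl::operator()().
void flash_attention2_example(half* q, half* k, half* v, half* out,
                              int batch, int heads, int key_len, int seq_len, int dim,
                              cudaStream_t stream)
{
    using Op = turbomind::FlashAttentionOpImpl<half, 2>;
    Op op(batch, heads, key_len, seq_len, dim);
    Op::Params params{};
    params.query      = q;
    params.key        = k;
    params.val        = v;
    params.attn_out   = out;
    params.group_size = 1;  // MHA: h == h_k, so h_h_k_ratio == 1
    // layout_q / layout_k / layout_v / layout_o strides (and optional cu_seqlens /
    // batch_seqs pointers) would be filled in from the caller's QKV memory layout.
    op(params, stream);     // dispatches to run_mha_fwd on the given stream
}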
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention
#pragma once
#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "softmax.h"
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void
softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2)
{
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
flash::reduce_sum(scores, scores_sum);
}
else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
float scores_max_cur = !Check_inf ? scores_max(mi) : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scores_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
Tensor scores_sum_cur = make_fragment_like(scores_sum);
flash::reduce_sum(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < size(scores_sum); ++mi) {
scores_sum(mi) += scores_sum_cur(mi);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void
write_softmax_to_gmem(Tensor<Engine0, Layout0> const& tOrP, Tensor<Engine1, Layout1>& tPgP, TiledCopy gmem_thr_copy_P)
{
#if (__CUDA_ARCH__ >= 800)
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
#else
// returning softmax is not supported on devices below sm80
assert(false);
#endif
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits,
bool Is_dropout,
bool Is_causal,
bool Is_even_N,
bool Is_even_K,
bool Return_softmax,
typename Params>
inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block)
{
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0)
return;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
int seq_diff = max(binfo.actual_seqlen_k - binfo.actual_seqlen_q, int(0));
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + seq_diff, kBlockN));
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
if (!params.q_enable_seqlen) {
row_offset_q =
bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
}
// We move K and V to the last block.
auto k_ptr = params.k_ptr;
auto k_batch_stride = params.k_batch_stride;
if (params.k_batched_ptr != nullptr) {
k_ptr = (reinterpret_cast<Element**>(params.k_batched_ptr))[bidb] + params.k_batched_offset;
k_batch_stride = 0;
}
const index_t row_offset_k = binfo.k_offset(k_batch_stride, params.k_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.k_row_stride
+ (bidh / params.h_h_k_ratio) * params.k_head_stride;
auto v_ptr = params.v_ptr;
auto v_batch_stride = params.v_batch_stride;
if (params.v_batched_ptr != nullptr) {
v_ptr = (reinterpret_cast<Element**>(params.v_batched_ptr))[bidb] + params.v_batched_offset;
v_batch_stride = 0;
}
const index_t row_offset_v = binfo.k_offset(v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride
+ (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p =
((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded
+ (n_block_max - 1) * kBlockN;
// gmem tensor
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
// smem tensor
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)), typename Kernel_traits::SmemLayoutQ{});
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK =
make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
// g->s thread copy
auto gmem_tiled_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{};
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
auto gmem_tiled_copy_P = typename Kernel_traits::GmemTiledCopyP{};
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
// tiles of g->s copy
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
// tiles of mma
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
//
// Copy Atom retiling
//
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
//
// PREDICATES
//
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) {
tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
}
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) {
tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
}
}
// Prologue
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) {
cute::cp_async_fence();
}
if (Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<0>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
// copy K
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<1>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
clear(acc_o);
// For performance reasons, we separate out two kinds of iterations:
// those that need masking on S and those that don't.
// We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) + 1 : 2;
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
}
else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(acc_s,
tSrQ,
tSrK,
tSsQ,
tSsK,
tiled_mma,
smem_tiled_copy_Q,
smem_tiled_copy_K,
smem_thr_copy_Q,
smem_thr_copy_K);
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) {
flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
else {
flash::apply_mask_causal(scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4
+ max(binfo.actual_seqlen_k - binfo.actual_seqlen_q, 0),
kNWarps * 16);
}
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// softmax
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0 ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(
scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) :
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(
scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= 0) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(acc_s,
tSrQ,
tSrK,
tSsQ,
tSsK,
tiled_mma,
smem_tiled_copy_Q,
smem_tiled_copy_K,
smem_thr_copy_Q,
smem_thr_copy_K);
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
float scale = inv_sum;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scale;
}
}
// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) {
__syncthreads();
}
copy(smem_tiled_copy_O, taccOrO, taccOsO);
__syncthreads();
index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
if (!params.o_enable_seqlen) {
row_offset_o =
bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
}
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{};
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_tiled_copy_O, tOsO, tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) {
tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
}
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits,
bool Is_dropout,
bool Is_causal,
bool Is_even_N,
bool Is_even_K,
bool Return_softmax,
typename Params>
inline __device__ void compute_attn(const Params& params)
{
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(
params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention
#pragma once
#include "flash.h"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "static_switch.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params)
{
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Varlen in the BlockInfo constructor, so we need to check
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
&& params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// const bool return_softmax = params.p_ptr != nullptr;
constexpr bool return_softmax = false;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst,
                                ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true,
// ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 32;
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
});
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 64;
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr (!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>,
// Is_dropout, Is_causal>(params, stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params,
// stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>,
// Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 128;
bool is_sm8x = (turbomind::getSMVersion() >= 80);
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr (!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr (!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>,
// Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false,
// false, T>, Is_dropout, Is_causal>(params, stream); Using 8 warps (128 x 128 and 256 x 64) is 28% slower
// for seqlen=2k run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>,
// Is_dropout, Is_causal>(params, stream); 1st ones are good for H100, A100 2nd one is good for A6000 bc we
// get slightly better occupancy
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>,
// Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true,
// true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params,
stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params,
// stream); 96 KB run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout,
// Is_causal>(params, stream);
});
}
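To make the shared-memory arithmetic above concrete: kernel_traits.h (below) defines kSmemSize = kSmemQSize + kSmemKVSize, which for half-precision elements works out to
$$ \mathrm{kSmemSize} = (\mathrm{kBlockM} + 2\,\mathrm{kBlockN}) \cdot \mathrm{kHeadDim} \cdot \mathrm{sizeof}(\mathrm{half}). $$
For Headdim = 256 this gives (128 + 2·64)·256·2 B = 128 KB for the 128x64 tile and (64 + 2·64)·256·2 B = 96 KB for the 64x64 tile, which is exactly what the 2 * Headdim * (128 + 2 * 64) guard above checks in bytes.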
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::half_t>
struct Flash_kernel_traits {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_,
int kBlockM_,
int kBlockN_,
int kNWarps_,
bool Is_Q_in_regs_ = false,
bool Share_Q_K_smem_ = false,
typename elem_type = cutlass::half_t,
typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
struct Flash_fwd_kernel_traits: public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>, _1, _1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ =
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using SmemLayoutAtomVtransposedNoSwizzle =
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed =
decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed =
decltype(tile_to_shape(SmemLayoutAtomVtransposed{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle =
decltype(tile_to_shape(SmemLayoutAtomVtransposedNoSwizzle{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{}, Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DefaultCopy>;
using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include "utils.h"
namespace flash {
template<bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ inline void
thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op)
{
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max)
{
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum)
{
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template<bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void
scale_apply_exp2(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, const float scale)
{
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
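Written out, the base-2 trick the inner loop relies on (in the Scale_max = true case used by softmax_rescale_o, where scale is scale_softmax_log2 = scale_softmax * M_LOG2E as set in flash_api.cpp):
$$ e^{\,s\,x_{ij} - s\,m_i} \;=\; 2^{\,x_{ij}\,(s\log_2 e) \;-\; m_i\,(s\log_2 e)} \;=\; \mathrm{exp2f}(x_{ij}\cdot\mathrm{scale} - \mathrm{max\_scaled}), $$
where s is the softmax scale and m_i the row maximum, so the multiply-and-subtract folds into a single FFMA feeding exp2f.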
using namespace cute;
template<typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout>& tensor, const int max_seqlen_k, const int col_idx_offset_ = 0)
{
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
template<typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor,
const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k,
const uint32_t row_idx_offset_,
const uint32_t warp_row_stride)
{
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
const uint32_t row_idx_offset = row_idx_offset_;
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8;
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
} // namespace flash
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} \
else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} \
else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
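A self-contained sketch of how these switches compose; launch_by_flag and dispatch are hypothetical stand-ins for run_flash_fwd / run_mha_fwd, which nest FP16_SWITCH and FWD_HEADDIM_SWITCH the same way in flash_api.cpp:
#include <cstdio>
// Runtime flags become constexpr template arguments inside the switch lambdas.
template<bool IsCausal, int HeadDim>
void launch_by_flag()  // hypothetical stand-in for the real kernel launcher
{
    std::printf("causal=%d head_dim=%d\n", int(IsCausal), HeadDim);
}

void dispatch(bool is_causal, int head_dim)
{
    BOOL_SWITCH(is_causal, IsCausalConst, [&] {
        FWD_HEADDIM_SWITCH(head_dim, [&] {
            // kHeadDim is the constexpr int introduced by FWD_HEADDIM_SWITCH above.
            launch_by_flag<IsCausalConst, kHeadDim>();
        });
    });
}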
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
} // namespace cute
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool A_in_regs = false,
bool B_in_regs = false,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename Tensor4,
typename TiledMma,
typename TiledCopyA,
typename TiledCopyB,
typename ThrCopyA,
typename ThrCopyB>
inline __device__ void gemm(Tensor0& acc,
Tensor1& tCrA,
Tensor2& tCrB,
Tensor3 const& tCsA,
Tensor4 const& tCsB,
TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A,
TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A,
ThrCopyB smem_thr_copy_B)
{
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) {
cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
if (!B_in_regs) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
}
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
if (!B_in_regs) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename TiledMma,
typename TiledCopy,
typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0& acc,
Tensor1& tCrA,
Tensor2& tCrB,
Tensor3 const& tCsB,
TiledMma tiled_mma,
TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B)
{
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template<int N>
CUTE_HOST_DEVICE void cp_async_wait()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_even_MN = true,
bool Is_even_K = true,
bool Clear_OOB_MN = false,
bool Clear_OOB_K = true,
typename TiledCopy,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Engine2,
typename Layout2,
typename Engine3,
typename Layout3>
inline __device__ void copy(TiledCopy thr_copy,
Tensor<Engine0, Layout0> const& S,
Tensor<Engine1, Layout1>& D,
Tensor<Engine2, Layout2> const& identity_MN,
Tensor<Engine3, Layout3> const& predicate_K,
int max_MN = 0)
{
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k));
}
else if (Clear_OOB_K) {
clear(D(_, m, k));
}
}
}
else if (Clear_OOB_MN) {
clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const& x, T const& y)
{
return x > y ? x : y;
}
};
template<>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const& x, float const& y)
{
return max(x, y);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const& x, T const& y)
{
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout)
{
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout)
{
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l =
logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), get<0, 1>(l), get<1, 1, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash