Unverified commit 452822a4, authored by q.yao, committed by GitHub

Add flashattention2 (#196)



* first

* fix causal mask

* disable flash attention2 on sm70

* fix 2

* update readme

* clang-format

* disable ft2 on windows

* fix lint

* fix build

* fix build

* fix long kv seq

* fix lint

* sync copy output

---------
Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
Co-authored-by: irexyc <irexyc@gmail.com>
parent d4d609bd
@@ -47,7 +47,7 @@ include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
- GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088
+ GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60
)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -312,6 +312,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BaseSamplingLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:llama_fmha>
+ $<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
...
@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
+ - \[2023/08\] TurboMind supports flash-attention2.
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
...
@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
+ - \[2023/08\] TurboMind supports flash-attention2
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. See [here](./docs/zh_cn/w4a16.md) for deployment instructions
...
@@ -41,5 +41,10 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
logger
llama_fmha)
+ if (NOT MSVC)
+     add_subdirectory(flash_attention2)
+     target_link_libraries(Llama PUBLIC flash_attention2)
+ endif()
add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
@@ -264,6 +264,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
{
//////////////////////////////////////////////
// flash attention
+ // flash attention 2 only supports half inputs
using AttentionOp = FlashAttentionOp<T>;
using Layout = typename AttentionOp::AttentionLayout;
Layout layout_q{
...
cmake_minimum_required(VERSION 3.8)
project(flash_attention2)
add_library(${PROJECT_NAME} STATIC
flash_api.cpp
flash_fwd_hdim32_fp16_sm80.cu
flash_fwd_hdim64_fp16_sm80.cu
flash_fwd_hdim128_fp16_sm80.cu
flash_fwd_hdim256_fp16_sm80.cu
)
target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass)
set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Flash Attention 2
This is a flash-attention2 implementation adapted from https://github.com/Dao-AILab/flash-attention, with the following changes:
- dropout removed
- backward pass removed
- built against cutlass 3.1.0
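
For orientation, the sketch below (editorial, not part of this PR) shows how the forward entry point is driven: populate a Flash_fwd_params the way FlashAttentionOpImpl does in flash_api.cpp further down, then hand it to run_mha_fwd. The helper name, its arguments, and the fixed-length contiguous half-precision [batch, seqlen, heads, head_dim] layout are illustrative assumptions.

// Editorial sketch, not part of this PR. Field names follow flash.h below;
// run_mha_fwd is the dispatcher defined in flash_api.cpp.
#include <math.h>
#include <string.h>
#include <cuda_runtime.h>
#include "flash.h"

void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream);  // defined in flash_api.cpp

void launch_fwd_sketch(void* q, void* k, void* v, void* o,
                       int batch, int seqlen, int heads, int head_dim, cudaStream_t stream)
{
    Flash_fwd_params p;
    memset(&p, 0, sizeof(p));
    p.q_ptr = q; p.k_ptr = k; p.v_ptr = v; p.o_ptr = o;
    // strides are in elements: one row holds heads * head_dim values
    p.q_row_stride = p.k_row_stride = p.v_row_stride = p.o_row_stride = heads * head_dim;
    p.q_head_stride = p.k_head_stride = p.v_head_stride = p.o_head_stride = head_dim;
    p.q_batch_stride = p.k_batch_stride = p.v_batch_stride = p.o_batch_stride = seqlen * heads * head_dim;
    p.h = p.h_k = heads;  // no MQA/GQA in this sketch
    p.h_h_k_ratio = 1;
    p.b = batch;
    p.seqlen_q = p.seqlen_k = seqlen;
    p.d = head_dim;
    p.scale_softmax = 1.f / sqrtf(float(head_dim));
    p.scale_softmax_log2 = p.scale_softmax * float(M_LOG2E);
    p.cu_seqlens_q = nullptr;  // fixed-length batch: no packed sequences
    p.cu_seqlens_k = nullptr;
    p.is_bf16 = false;  // this port only instantiates fp16 kernels
    p.is_causal = true;
    run_mha_fwd(p, stream);
}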
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen = true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
template<typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template<typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
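
The two addressing modes above are easy to miss: with a fixed-length batch (cu_seqlens == nullptr) sum_s_q stays -1 and blocks are addressed by batch_stride, while a packed variable-length batch is addressed through the cumulative sequence lengths. A standalone host-side mirror of that arithmetic with hypothetical numbers (editorial, not part of this PR):

#include <assert.h>
#include <stdint.h>

// Mirrors BlockInfo::q_offset for illustration; BlockInfo itself is __device__-only.
static uint32_t q_offset_like(int sum_s_q, uint32_t batch_stride, uint32_t row_stride, int bidb)
{
    // sum_s_q == -1: fixed-length batch, address by batch index.
    // otherwise: packed variable-length batch, address by cumulative row count.
    return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

int main()
{
    // Fixed-length: cu_seqlens_q == nullptr, so sum_s_q stays -1.
    assert(q_offset_like(-1, /*batch_stride=*/4096, /*row_stride=*/64, /*bidb=*/2) == 2u * 4096u);
    // Packed: cu_seqlens_q = {0, 7, 12, 30}; sequence 1 starts at row 7 and spans 12 - 7 = 5 rows.
    assert(q_offset_like(7, 4096, 64, 1) == 7u * 64u);
    return 0;
}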
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;
// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precomputed h / h_k
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params: public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void* __restrict__ p_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;
void* __restrict__ blockmask;
bool is_bf16;
bool is_causal;
// whether the q / o layouts are indexed with cu_seqlens (packed variable-length)
bool q_enable_seqlen;
bool o_enable_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention
#include "flash.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "static_switch.h"
#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include <math.h>
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
FP16_SWITCH(true,
[&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); }); });
}
namespace turbomind {
static constexpr int FMHA_VERSION = 2;
template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION> {
public:
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;
public:
FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOpImpl();
int get_workspace_size() const;
void operator()(Params& params, cudaStream_t st) const;
private:
class impl;
std::unique_ptr<impl> pimpl;
};
template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {
private:
using scalar_t =
typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
using Params = typename FlashAttentionOpImpl<T, FMHA_VERSION>::Params;
int batch_size_;
int head_num_;
int key_len_;
int seq_len_;
int size_per_head_;
public:
impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
batch_size_(batch_size),
head_num_(head_num),
key_len_(key_len),
seq_len_(seq_len),
size_per_head_(size_per_head)
{
}
~impl() {}
int get_workspace_size() const
{
return 0;
}
void operator()(Params& params, cudaStream_t st) const
{
const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head_ * 1.f));
Flash_fwd_params fwd_params;
memset(&fwd_params, 0, sizeof(fwd_params));
fwd_params.q_ptr = reinterpret_cast<void*>(params.query);
fwd_params.k_ptr = reinterpret_cast<void*>(params.key);
fwd_params.v_ptr = reinterpret_cast<void*>(params.val);
fwd_params.k_batched_ptr = reinterpret_cast<void**>(params.layout_k.batch_seqs);
fwd_params.v_batched_ptr = reinterpret_cast<void**>(params.layout_v.batch_seqs);
fwd_params.k_batched_offset = params.layout_k.batch_seqs_offset;
fwd_params.v_batched_offset = params.layout_v.batch_seqs_offset;
fwd_params.q_batch_stride = params.layout_q.stride_batch;
fwd_params.k_batch_stride = params.layout_k.stride_batch;
fwd_params.v_batch_stride = params.layout_v.stride_batch;
fwd_params.q_row_stride = params.layout_q.stride_seq;
fwd_params.k_row_stride = params.layout_k.stride_seq;
fwd_params.v_row_stride = params.layout_v.stride_seq;
fwd_params.q_head_stride = params.layout_q.stride_head;
fwd_params.v_head_stride = params.layout_v.stride_head;
fwd_params.k_head_stride = params.layout_k.stride_head;
fwd_params.h = head_num_;
fwd_params.h_k = head_num_ / params.group_size;
fwd_params.h_h_k_ratio = params.group_size;
fwd_params.o_ptr = reinterpret_cast<void*>(params.attn_out);
fwd_params.o_batch_stride = params.layout_o.stride_batch;
fwd_params.o_row_stride = params.layout_o.stride_seq;
fwd_params.o_head_stride = params.layout_o.stride_head;
fwd_params.p_ptr = nullptr;
fwd_params.b = batch_size_;
fwd_params.seqlen_q = seq_len_;
fwd_params.seqlen_k = key_len_;
fwd_params.d = size_per_head_;
fwd_params.seqlen_q_rounded = 0;
fwd_params.seqlen_k_rounded = 0;
fwd_params.scale_softmax = qk_scale;
fwd_params.scale_softmax_log2 = qk_scale * M_LOG2E;
fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;
fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
fwd_params.is_bf16 = false;
fwd_params.is_causal = true;
fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
fwd_params.o_enable_seqlen = params.layout_o.use_seqlens;
run_mha_fwd(fwd_params, st);
}
};
template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::FlashAttentionOpImpl(
int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
pimpl{std::make_unique<FlashAttentionOpImpl<T, FMHA_VERSION>::impl>(
batch_size, head_num, key_len, seq_len, size_per_head)}
{
}
template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::~FlashAttentionOpImpl()
{
}
template<typename T>
int FlashAttentionOpImpl<T, FMHA_VERSION>::get_workspace_size() const
{
return pimpl->get_workspace_size();
}
template<typename T>
void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStream_t st) const
{
pimpl->operator()(params, st);
}
template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;
} // namespace turbomind
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
(collapsed diff not shown)
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention
#pragma once
#include "flash.h"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "static_switch.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params)
{
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
&& params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// const bool return_softmax = params.p_ptr != nullptr;
constexpr bool return_softmax = false;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst,
ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true,
// ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 32;
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
});
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 64;
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr (!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>,
// Is_dropout, Is_causal>(params, stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params,
// stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>,
// Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 128;
bool is_sm8x = (turbomind::getSMVersion() >= 80);
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr (!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr (!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(
params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>,
// Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false,
// false, T>, Is_dropout, Is_causal>(params, stream); Using 8 warps (128 x 128 and 256 x 64) is 28% slower
// for seqlen=2k run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>,
// Is_dropout, Is_causal>(params, stream); 1st ones are good for H100, A100 2nd one is good for A6000 bc we
// get slightly better occupancy
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout,
// Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>,
// Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true,
// true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream)
{
constexpr int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
static constexpr bool Is_dropout = false;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params,
stream);
}
else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params,
stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params,
// stream); 96 KB run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout,
// Is_causal>(params, stream);
});
}
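
The shared-memory thresholds in run_mha_fwd_hdim256 follow directly from the tile sizes (fp16, 2 bytes per element; per kernel_traits.h a CTA stages one kBlockM x Headdim Q tile plus K and V tiles of kBlockN x Headdim each). A worked check of the two bounds used above (editorial, not part of this PR):

constexpr int kHeaddim256   = 256;
constexpr int kBytesPerElem = 2;  // cutlass::half_t
// 128 x 64 tile: (128 + 2 * 64) rows of 256 halfs = 131072 B = 128 KB per CTA
constexpr int smem_128x64 = (128 + 2 * 64) * kHeaddim256 * kBytesPerElem;
// 64 x 64 tile: (64 + 2 * 64) rows of 256 halfs = 98304 B = 96 KB per CTA
constexpr int smem_64x64 = (64 + 2 * 64) * kHeaddim256 * kBytesPerElem;
static_assert(smem_128x64 == 2 * kHeaddim256 * (128 + 2 * 64), "the max_smem_per_block bound");
static_assert(2 * smem_64x64 == 4 * kHeaddim256 * (64 + 2 * 64), "two 64x64 CTAs, the max_smem_per_sm bound");
// So 128 x 64 is chosen only when a block may opt into 128 KB of dynamic smem and the SM
// could not host two 96 KB CTAs anyway (A100-like); otherwise 64 x 64 keeps 2 CTAs per SM possible.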
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::half_t>
struct Flash_kernel_traits {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_,
int kBlockM_,
int kBlockN_,
int kNWarps_,
bool Is_Q_in_regs_ = false,
bool Share_Q_K_smem_ = false,
typename elem_type = cutlass::half_t,
typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
struct Flash_fwd_kernel_traits: public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>, _1, _1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ =
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using SmemLayoutAtomVtransposedNoSwizzle =
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed =
decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed =
decltype(tile_to_shape(SmemLayoutAtomVtransposed{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle =
decltype(tile_to_shape(SmemLayoutAtomVtransposedNoSwizzle{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{}, Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DefaultCopy>;
using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
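
Plugging a concrete configuration into the sizes above ties them back to the launch heuristics: for the hdim128, 128 x 32 tile used on the sm8x non-causal path (fp16, Share_Q_K_smem = false), the numbers work out as follows (editorial note, not part of this PR).

// kSmemQSize  = kBlockM * kHeadDim * sizeof(half) = 128 * 128 * 2 = 32 KB
// kSmemKVSize = 2 * kBlockN * kHeadDim * sizeof(half) = 2 * 32 * 128 * 2 = 16 KB
// kSmemSize   = 32 KB + 16 KB = 48 KB
// which is the "128 x 32 (48 KB smem)" figure quoted in run_mha_fwd_hdim128, small enough
// for two CTAs to share the ~100 KB of shared memory available per SM on sm86/sm89.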
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include "utils.h"
namespace flash {
template<bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ inline void
thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op)
{
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max)
{
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum)
{
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template<bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void
scale_apply_exp2(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, const float scale)
{
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
using namespace cute;
template<typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout>& tensor, const int max_seqlen_k, const int col_idx_offset_ = 0)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
template<typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor,
const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k,
const uint32_t row_idx_offset_,
const uint32_t warp_row_stride)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
const uint32_t row_idx_offset = row_idx_offset_;
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8;
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
} // namespace flash
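
The Scale_max path of scale_apply_exp2 is why flash_api.cpp sets scale_softmax_log2 = qk_scale * M_LOG2E: with that scale, exp2f(s * scale - max * scale) equals the conventional expf((s - max) * qk_scale), while letting the compiler fuse the scale and subtraction into a single FFMA before the exp2. A small host-side check with hypothetical values (editorial, not part of this PR):

#include <assert.h>
#include <math.h>
#include <stdio.h>

int main()
{
    const float qk_scale   = 1.f / sqrtf(128.f);          // hypothetical head_dim = 128
    const float scale_log2 = qk_scale * float(M_LOG2E);   // what flash_api.cpp passes as scale_softmax_log2
    const float s = 3.7f, row_max = 5.2f;                 // hypothetical score and row max
    const float base2 = exp2f(s * scale_log2 - row_max * scale_log2);
    const float basee = expf((s - row_max) * qk_scale);
    assert(fabsf(base2 - basee) < 1e-6f);
    printf("exp2 form: %.8f, exp form: %.8f\n", base2, basee);
    return 0;
}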
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} \
else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} \
else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} \
else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
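
Combined with run_mha_fwd in flash_api.cpp, these switches reduce at runtime to one of the four explicit instantiations provided by the flash_fwd_hdim*_fp16_sm80.cu files. An equivalent hand-written dispatch, for illustration only (editorial, not part of this PR):

#include <cutlass/numeric_types.h>
#include "flash.h"  // Flash_fwd_params and the run_mha_fwd_<T, Headdim> declaration

void run_mha_fwd_equivalent(Flash_fwd_params& params, cudaStream_t stream)
{
    // FP16_SWITCH(true, ...) always picks cutlass::half_t in this port;
    // FWD_HEADDIM_SWITCH rounds the head dim up to the next instantiated size.
    if (params.d <= 32) {
        run_mha_fwd_<cutlass::half_t, 32>(params, stream);
    }
    else if (params.d <= 64) {
        run_mha_fwd_<cutlass::half_t, 64>(params, stream);
    }
    else if (params.d <= 128) {
        run_mha_fwd_<cutlass::half_t, 128>(params, stream);
    }
    else if (params.d <= 256) {
        run_mha_fwd_<cutlass::half_t, 256>(params, stream);
    }
    // head dims above 256 fall through: FWD_HEADDIM_SWITCH has no branch for them
}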
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
} // namespace cute
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool A_in_regs = false,
bool B_in_regs = false,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename Tensor4,
typename TiledMma,
typename TiledCopyA,
typename TiledCopyB,
typename ThrCopyA,
typename ThrCopyB>
inline __device__ void gemm(Tensor0& acc,
Tensor1& tCrA,
Tensor2& tCrB,
Tensor3 const& tCsA,
Tensor4 const& tCsB,
TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A,
TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A,
ThrCopyB smem_thr_copy_B)
{
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) {
cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
if (!B_in_regs) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
}
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
if (!B_in_regs) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename TiledMma,
typename TiledCopy,
typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0& acc,
Tensor1& tCrA,
Tensor2& tCrB,
Tensor3 const& tCsB,
TiledMma tiled_mma,
TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B)
{
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template<int N>
CUTE_HOST_DEVICE void cp_async_wait()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_even_MN = true,
bool Is_even_K = true,
bool Clear_OOB_MN = false,
bool Clear_OOB_K = true,
typename TiledCopy,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Engine2,
typename Layout2,
typename Engine3,
typename Layout3>
inline __device__ void copy(TiledCopy thr_copy,
Tensor<Engine0, Layout0> const& S,
Tensor<Engine1, Layout1>& D,
Tensor<Engine2, Layout2> const& identity_MN,
Tensor<Engine3, Layout3> const& predicate_K,
int max_MN = 0)
{
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k));
}
else if (Clear_OOB_K) {
clear(D(_, m, k));
}
}
}
else if (Clear_OOB_MN) {
clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const& x, T const& y)
{
return x > y ? x : y;
}
};
template<>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const& x, float const& y)
{
return max(x, y);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const& x, T const& y)
{
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout)
{
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout)
{
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l =
logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), get<0, 1>(l), get<1, 1, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash