"...text-generation-inference.git" did not exist on "cd5961b5dad560d63f4dd42d08d6ee3877b82003"
Unverified commit 4b121180, authored by q.yao, committed by GitHub

Fix fmha on sm 70 (#12)



* update arch

* clang-format

* remove comment

---------
Co-authored-by: yaoqian <yaoqian@localhost.localdomain>
parent 102aefda
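
Context for the fix: the kernel file chose its cutlass ArchTag with a compile-time #if on __CUDA_ARCH__ (removed in the first hunk below). That macro is defined only during nvcc's device-compilation pass, so host code, including the launch setup that decides shared-memory sizes and whether an output accumulator buffer is needed, always took the !defined(__CUDA_ARCH__) branch and behaved as if the GPU were Sm80, even on an sm70 device such as V100. A minimal standalone demo of the two-pass behavior (illustrative, not part of the patch):

// demo_cuda_arch.cu (illustrative): __CUDA_ARCH__ exists only in the device pass,
// so a compile-time arch check always takes the "no arch" branch in host code.
#include <cstdio>
#include <cuda_runtime.h>

__host__ __device__ void report_pass()
{
#if defined(__CUDA_ARCH__)
    printf("device pass: __CUDA_ARCH__ = %d\n", __CUDA_ARCH__);
#else
    printf("host pass: __CUDA_ARCH__ is undefined\n");
#endif
}

__global__ void report_kernel()
{
    report_pass();
}

int main()
{
    report_pass();              // host: always the !defined(__CUDA_ARCH__) branch
    report_kernel<<<1, 1>>>();  // device: sees the real target, e.g. 700 on sm70
    cudaDeviceSynchronize();
    return 0;
}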
#include "src/fastertransformer/models/llama/llama_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "42_fused_multi_head_attention/kernel_forward.h"
#include "mma_accum_lambda_iterator.h"
@@ -14,18 +15,6 @@
namespace fastertransformer {
-#if !defined(__CUDA_ARCH__)
-using ArchTag = cutlass::arch::Sm80;
-#else
-#if (__CUDA_ARCH__ >= 800)
-using ArchTag = cutlass::arch::Sm80;
-#elif (__CUDA_ARCH__ >= 750)
-using ArchTag = cutlass::arch::Sm75;
-#elif (__CUDA_ARCH__ >= 700)
-using ArchTag = cutlass::arch::Sm70;
-#endif
-#endif
template<
    // dtype of Q/K/V/M
    typename Element_,
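
The replacement, added in the next hunk, dispatches at runtime instead: the host queries the installed GPU via getSMVersion() (from cuda_utils.h) and selects among kernels instantiated for Sm70, Sm75, and Sm80. A self-contained sketch of that pattern, with stand-in tag types in place of the cutlass ones:

// sketch_runtime_dispatch.cu (illustrative): choose the kernel instantiation
// from the SM version of the device actually present, not from __CUDA_ARCH__.
#include <cstdio>
#include <cuda_runtime.h>

struct Sm70 { static constexpr int value = 70; };
struct Sm75 { static constexpr int value = 75; };
struct Sm80 { static constexpr int value = 80; };

// Stand-in for FasterTransformer's getSMVersion(): major * 10 + minor.
int runtime_sm_version()
{
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major * 10 + prop.minor;
}

// Stand-in for invokeFlashAttention_impl<T, AttentionKernel>(...).
template<typename ArchTag>
void launch_attention()
{
    std::printf("launching the Sm%d instantiation\n", ArchTag::value);
}

int main()
{
    // Same shape as the switch added by this commit: sm75 is its own case,
    // sm80 and newer take the Sm80 kernels, everything older falls back to Sm70.
    int sm = runtime_sm_version();
    switch (sm) {
        case 75:
            launch_attention<Sm75>();
            break;
        default:
            if (sm >= 80) {
                launch_attention<Sm80>();
            }
            else {
                launch_attention<Sm70>();
            }
    }
    return 0;
}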
@@ -759,12 +748,89 @@ void invokeFlashAttention_impl(int batch_size,
    attention_kernel_batched_impl<Attention><<<block_grid, thread_grid, smem_bytes, st>>>(params);
}
+#define CUTLASS_ARCH(sm) cutlass::arch::Sm##sm
+
+#define ATTENTION_KERNEL(scalar_t, sm, querys_per_block, keys_per_block, single_value)                             \
+    LlamaAttentionKernel<scalar_t, CUTLASS_ARCH(sm), querys_per_block, keys_per_block, single_value>
+
+template<typename T, int kQueriesPerBlock, int kKeysPerBlock>
+bool get_needs_accum_buffer()
+{
+    using scalar_t =
+        typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+
+#define GET_NEED_ACCUM_BUFFER(sm)                                                                                  \
+    ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, false)::kNeedsOutputAccumulatorBuffer
+
+    auto sm = getSMVersion();
+    switch (sm) {
+        case 75:
+            return GET_NEED_ACCUM_BUFFER(75);
+        default:
+            if (sm >= 80) {
+                return GET_NEED_ACCUM_BUFFER(80);
+            }
+            else {
+                return GET_NEED_ACCUM_BUFFER(70);
+            }
+    }
+#undef GET_NEED_ACCUM_BUFFER
+}
+
+template<typename T, int kQueriesPerBlock, int kKeysPerBlock>
+void invoke_attention_impl(bool single_val_iteration,
+                           int batch_size,
+                           int head_num,
+                           int key_len,
+                           int seq_len,
+                           int size_per_head,
+                           typename FlashAttentionOp<T>::Params& params,
+                           cudaStream_t st)
+{
+    using scalar_t =
+        typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+
+#define INVOKE_ATTEN_IMPL(sm, single_value)                                                                        \
+    {                                                                                                              \
+        using AttentionKernel = ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, single_value);     \
+        invokeFlashAttention_impl<T, AttentionKernel>(                                                             \
+            batch_size, head_num, key_len, seq_len, size_per_head, params, st);                                    \
+    }
+
+#define INVOKE_ATTENN_IMPL_V2(sm)                                                                                  \
+    {                                                                                                              \
+        if (single_val_iteration)                                                                                  \
+            INVOKE_ATTEN_IMPL(sm, true)                                                                            \
+        else                                                                                                       \
+            INVOKE_ATTEN_IMPL(sm, false)                                                                           \
+    }
+
+    auto sm = getSMVersion();
+    switch (sm) {
+        case 75:
+            INVOKE_ATTENN_IMPL_V2(75);
+            break;
+        default:
+            if (sm >= 80) {
+                INVOKE_ATTENN_IMPL_V2(80);
+            }
+            else {
+                INVOKE_ATTENN_IMPL_V2(70);
+            }
+    }
+#undef INVOKE_ATTENN_IMPL_V2
+#undef INVOKE_ATTEN_IMPL
+}
template<typename T>
class FlashAttentionOp<T>::impl {
private:
    static constexpr int kQueriesPerBlock = 32;
    static constexpr int kKeysPerBlock = 128;

    using ArchTag = cutlass::arch::Sm80;
    using scalar_t =
        typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
    using SingleValueAttention = LlamaAttentionKernel<scalar_t, ArchTag, kQueriesPerBlock, kKeysPerBlock, true>;
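
A note on the macros in the hunk above: CUTLASS_ARCH builds the architecture tag's type name by token pasting, and ATTENTION_KERNEL composes it into the full kernel type, so each branch of the runtime switch instantiates a separate template. Roughly how the expansion resolves, with stand-in types rather than the real cutlass and kernel_forward.h definitions (the real kNeedsOutputAccumulatorBuffer condition is more involved than this stand-in):

// Stand-ins, for illustration only.
namespace cutlass { namespace arch { struct Sm70 {}; struct Sm75 {}; struct Sm80 {}; } }

template<typename scalar_t, typename ArchTag, int kQueriesPerBlock, int kKeysPerBlock, bool kSingleValueIteration>
struct LlamaAttentionKernel {
    // Stand-in rule: multi-value iteration needs a float accumulator buffer.
    static constexpr bool kNeedsOutputAccumulatorBuffer = !kSingleValueIteration;
};

#define CUTLASS_ARCH(sm) cutlass::arch::Sm##sm  // CUTLASS_ARCH(75) -> cutlass::arch::Sm75
#define ATTENTION_KERNEL(scalar_t, sm, querys_per_block, keys_per_block, single_value) \
    LlamaAttentionKernel<scalar_t, CUTLASS_ARCH(sm), querys_per_block, keys_per_block, single_value>

// ATTENTION_KERNEL(float, 75, 32, 128, false) names
// LlamaAttentionKernel<float, cutlass::arch::Sm75, 32, 128, false>:
static_assert(ATTENTION_KERNEL(float, 75, 32, 128, false)::kNeedsOutputAccumulatorBuffer,
              "multi-value kernels report the accumulator-buffer requirement");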
@@ -798,7 +864,7 @@ public:
        return 0;
    }
    else {
-       constexpr bool kNeedsOutputAccumulatorBuffer = MultiValueAttention::kNeedsOutputAccumulatorBuffer;
+       bool kNeedsOutputAccumulatorBuffer = get_needs_accum_buffer<T, kQueriesPerBlock, kKeysPerBlock>();
        if (kNeedsOutputAccumulatorBuffer) {
            return batch_size_ * head_num_ * seq_len_ * size_per_head_ * sizeof(float);
        }
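
The hunk above is the caller-visible consequence: whether the float accumulator workspace is needed used to be a compile-time constant taken from the Sm80 instantiation, and is now answered at runtime for the GPU actually installed, so the returned size can legitimately differ between, say, an A100 and a V100. A hypothetical caller-side sketch; the constructor arguments and the workspace-query method name are assumptions inferred from the members visible in this diff, not the verified API:

// Hypothetical usage (constructor/method names assumed, see above).
#include <cuda_runtime.h>

template<typename T>
void run_flash_attention(typename fastertransformer::FlashAttentionOp<T>::Params& params, cudaStream_t stream)
{
    // Shapes are illustrative.
    fastertransformer::FlashAttentionOp<T> op(
        /*batch_size=*/1, /*head_num=*/32, /*key_len=*/2048, /*seq_len=*/2048, /*size_per_head=*/128);

    // After this commit the answer depends on getSMVersion(), so query per run:
    // batch * head * seq * size_per_head * sizeof(float) when needed, 0 otherwise.
    size_t accum_bytes = op.get_workspace_size();  // assumed name
    void*  accum = nullptr;
    if (accum_bytes > 0) {
        cudaMalloc(&accum, accum_bytes);
        // ... point the accumulator field of `params` at this buffer ...
    }

    op(params, stream);  // operator()(Params&, cudaStream_t) const, as in the diff

    if (accum != nullptr) {
        cudaFree(accum);
    }
}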
@@ -810,14 +876,8 @@
    void operator()(Params& params, cudaStream_t st) const
    {
-       if (kSingleValueIteration) {
-           invokeFlashAttention_impl<T, SingleValueAttention>(
-               batch_size_, head_num_, key_len_, seq_len_, size_per_head_, params, st);
-       }
-       else {
-           invokeFlashAttention_impl<T, MultiValueAttention>(
-               batch_size_, head_num_, key_len_, seq_len_, size_per_head_, params, st);
-       }
+       invoke_attention_impl<T, kQueriesPerBlock, kKeysPerBlock>(
+           kSingleValueIteration, batch_size_, head_num_, key_len_, seq_len_, size_per_head_, params, st);
    }
};