Unverified commit 452822a4 authored by q.yao, committed by GitHub

Add flashattention2 (#196)



* first

* fix causal mask

* disable flash attention2 on sm70

* fix 2

* update readme

* clang-format

* disable ft2 on windows

* fix lint

* fix build

* fix build

* fix long kv seq

* fix lint

* sync copy output

---------
Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
Co-authored-by: irexyc <irexyc@gmail.com>
parent d4d609bd
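
For context, a minimal usage sketch of the FlashAttentionOp<T> wrapper this commit introduces. The constructor arguments, the Params/AttentionLayout fields, and operator() follow the header changes further down in this diff; the helper function itself, its tensor pointers, and the assumed [batch, head, seq, dim] layout are illustrative placeholders, not code from this commit.

// Minimal sketch (assumes contiguous [batch, head, seq, dim] tensors, fp16 inputs).
// FlashAttentionOp<half> selects the new FMHA v2 kernel on sm80+ non-Windows builds
// and falls back to the v1 kernel otherwise.
#include <cuda_fp16.h>
#include "src/turbomind/models/llama/llama_kernels.h"

void run_context_attention(half*        out,           // [batch, head, seq_len, dim]
                           half*        query,         // [batch, head, seq_len, dim]
                           half*        key,           // [batch, head, key_len, dim]
                           half*        value,         // [batch, head, key_len, dim]
                           half*        mask,          // [batch, seq_len, key_len]
                           float*       accum_buffer,  // optional fp32 workspace, sized via get_workspace_size()
                           int*         cu_seqlens_q,  // cumulative query lengths, or nullptr for dense batches
                           int          batch,
                           int          heads,
                           int          key_len,
                           int          seq_len,
                           int          dim,
                           cudaStream_t stream)
{
    using Op     = turbomind::FlashAttentionOp<half>;
    using Layout = Op::AttentionLayout;

    Op op(batch, heads, key_len, seq_len, dim);

    // Layout fields are {stride_batch, stride_seq, stride_head}, in elements.
    Layout layout_q{heads * seq_len * dim, dim, seq_len * dim};
    Layout layout_k{heads * key_len * dim, dim, key_len * dim};
    Layout layout_v{heads * key_len * dim, dim, key_len * dim};
    Layout layout_o{heads * seq_len * dim, dim, seq_len * dim};

    Op::Params params{out,          query,    key,
                      value,        mask,     accum_buffer,
                      cu_seqlens_q, nullptr,  /*group_size=*/1,
                      layout_q,     layout_k, layout_v, layout_o};
    op(params, stream);
}
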
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "42_fused_multi_head_attention/kernel_forward.h"
#include "mma_accum_lambda_iterator.h"
#include "tile_smem_loader.h"
#include "41_fused_multi_head_attention/kernel_forward.h"
#include <cuda_fp16.h>
#include <cutlass/arch/arch.h>
#include <cutlass/gemm/gemm.h>
@@ -77,20 +75,10 @@ struct LlamaAttentionKernel:
int v_batch_seqs_offset = 0;
int o_batch_seqs_offset = 0;
int32_t o_strideM_custom = 0;
int32_t group_size = 1;
float scale;
CUTLASS_HOST_DEVICE int32_t o_strideM() const
{
if (o_strideM_custom == 0)
return BaseParams::head_dim_value;
else
return o_strideM_custom;
}
template<typename ptr_t>
CUTLASS_DEVICE void
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB)
@@ -107,8 +95,8 @@ struct LlamaAttentionKernel:
auto& query_ptr = BaseParams::query_ptr;
auto& key_ptr = BaseParams::key_ptr;
auto& value_ptr = BaseParams::value_ptr;
auto& cu_seqlens_q_ptr = BaseParams::cu_seqlens_q_ptr;
auto& cu_seqlens_k_ptr = BaseParams::cu_seqlens_k_ptr;
auto& cu_seqlens_q_ptr = BaseParams::seqstart_q_ptr;
auto& cu_seqlens_k_ptr = BaseParams::seqstart_k_ptr;
auto& output_ptr = BaseParams::output_ptr;
auto& output_accum_ptr = BaseParams::output_accum_ptr;
@@ -119,22 +107,19 @@ struct LlamaAttentionKernel:
auto& num_queries = BaseParams::num_queries;
auto& num_keys = BaseParams::num_keys;
auto& causal = BaseParams::causal;
auto& q_strideM = BaseParams::q_strideM;
auto& k_strideM = BaseParams::k_strideM;
auto& v_strideM = BaseParams::v_strideM;
auto& o_strideM = BaseParams::o_strideM;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
auto& q_strideH = BaseParams::q_strideH;
auto& k_strideH = BaseParams::k_strideH;
auto& v_strideH = BaseParams::v_strideH;
auto& o_strideH = BaseParams::o_strideH;
auto& q_strideB = BaseParams::q_strideB;
auto& k_strideB = BaseParams::k_strideB;
auto& v_strideB = BaseParams::v_strideB;
auto& o_strideB = BaseParams::o_strideB;
auto& num_batches = BaseParams::num_batches;
auto& num_heads = BaseParams::num_heads;
@@ -142,6 +127,8 @@ struct LlamaAttentionKernel:
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto o_strideB = o_strideM * num_queries;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
int64_t q_start, k_start;
@@ -203,10 +190,10 @@ struct LlamaAttentionKernel:
query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH;
output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH;
output_ptr += int64_t(qo_start + query_start) * o_strideM + head_id * head_dim_value;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(query_start) * o_strideM() + head_id * o_strideH;
output_accum_ptr += int64_t(query_start) * o_strideM + head_id * head_dim_value;
}
else {
// Accumulate directly in the destination buffer (eg for f32)
@@ -218,9 +205,6 @@ struct LlamaAttentionKernel:
}
num_queries -= query_start;
if (causal) {
num_keys = cutlass::fast_min(int32_t(query_start + kQueriesPerBlock), num_keys);
}
num_batches = 0; // no longer used after
// Make sure the compiler knows these variables are the same on all
@@ -328,7 +312,7 @@ struct LlamaAttentionKernel:
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
return OutputTileIterator(typename OutputTileIterator::Params{(int32_t)p.o_strideM},
p.output_ptr,
typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value},
thread_id(),
@@ -338,7 +322,7 @@ struct LlamaAttentionKernel:
auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value},
thread_id(),
@@ -453,28 +437,27 @@ struct LlamaAttentionKernel:
[&](int accum_m) {});
}
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<kQueriesPerBlock, kFullColumns, kIsFirst, kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
Base::iterative_softmax<MM0::Mma::Operator::IteratorC, kFullColumns, kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
@@ -651,13 +634,13 @@ struct LlamaAttentionKernel:
};
template<typename T, typename Attention>
void invokeFlashAttention_impl(int batch_size,
int head_num,
int key_len,
int seq_len,
int size_per_head,
typename FlashAttentionOp<T>::Params& attention_params,
cudaStream_t st)
void invokeFlashAttention_impl(int batch_size,
int head_num,
int key_len,
int seq_len,
int size_per_head,
typename FlashAttentionOpImpl<T, 1>::Params& attention_params,
cudaStream_t st)
{
T* out_ptr = attention_params.attn_out;
T* query_ptr = attention_params.query;
@@ -685,11 +668,11 @@ void invokeFlashAttention_impl(int batch_size,
// fill param
typename Attention::Params params{};
{
params.query_ptr = (scalar_t*)(query_ptr);
params.key_ptr = (scalar_t*)(key_ptr);
params.value_ptr = (scalar_t*)(value_ptr);
params.attn_bias_ptr = (scalar_t*)(mask_ptr);
params.cu_seqlens_q_ptr = cu_seqlens_q_ptr;
params.query_ptr = (scalar_t*)(query_ptr);
params.key_ptr = (scalar_t*)(key_ptr);
params.value_ptr = (scalar_t*)(value_ptr);
params.attn_bias_ptr = (scalar_t*)(mask_ptr);
params.seqstart_q_ptr = cu_seqlens_q_ptr;
params.output_ptr = (scalar_t*)(out_ptr);
params.output_accum_ptr = kNeedsOutputAccumulatorBuffer ? output_accum_ptr : nullptr;
@@ -725,9 +708,7 @@ void invokeFlashAttention_impl(int batch_size,
params.v_batch_seqs_ptr = (scalar_t**)layout_v.batch_seqs;
params.v_batch_seqs_offset = layout_v.batch_seqs_offset;
params.o_strideH = layout_o.stride_head;
params.o_strideM_custom = layout_o.stride_seq;
params.o_strideB = layout_o.stride_batch;
params.o_strideM = layout_o.stride_seq;
params.o_use_seqlens = layout_o.use_seqlens;
params.o_batch_seqs_ptr = (scalar_t**)layout_o.batch_seqs;
params.o_batch_seqs_offset = layout_o.batch_seqs_offset;
@@ -784,14 +765,14 @@ bool get_needs_accum_buffer()
}
template<typename T, int kQueriesPerBlock, int kKeysPerBlock>
void invoke_attention_impl(bool single_val_iteration,
int batch_size,
int head_num,
int key_len,
int seq_len,
int size_per_head,
typename FlashAttentionOp<T>::Params& params,
cudaStream_t st)
void invoke_attention_impl(bool single_val_iteration,
int batch_size,
int head_num,
int key_len,
int seq_len,
int size_per_head,
typename FlashAttentionOpImpl<T, 1>::Params& params,
cudaStream_t st)
{
using scalar_t =
typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
@@ -830,18 +811,34 @@ void invoke_attention_impl(bool single_val_iter
}
template<typename T>
class FlashAttentionOp<T>::impl {
class FlashAttentionOpImpl<T, 1> {
public:
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;
public:
FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOpImpl();
int get_workspace_size() const;
void operator()(Params& params, cudaStream_t st) const;
private:
class impl;
std::unique_ptr<impl> pimpl;
};
template<typename T>
class FlashAttentionOpImpl<T, 1>::impl {
private:
static constexpr int kQueriesPerBlock = 32;
static constexpr int kKeysPerBlock = 128;
using ArchTag = cutlass::arch::Sm80;
using scalar_t =
typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
using SingleValueAttention = LlamaAttentionKernel<scalar_t, ArchTag, kQueriesPerBlock, kKeysPerBlock, true>;
using MultiValueAttention = LlamaAttentionKernel<scalar_t, ArchTag, kQueriesPerBlock, kKeysPerBlock, false>;
using AttentionLayout = typename FlashAttentionOp<T>::AttentionLayout;
using Params = typename FlashAttentionOp<T>::Params;
using Params = typename FlashAttentionOpImpl<T, 1>::Params;
int batch_size_;
int head_num_;
@@ -887,29 +884,30 @@ public:
};
template<typename T>
FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
pimpl{std::make_unique<FlashAttentionOp<T>::impl>(batch_size, head_num, key_len, seq_len, size_per_head)}
FlashAttentionOpImpl<T, 1>::FlashAttentionOpImpl(
int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
pimpl{std::make_unique<FlashAttentionOpImpl<T, 1>::impl>(batch_size, head_num, key_len, seq_len, size_per_head)}
{
}
template<typename T>
FlashAttentionOp<T>::~FlashAttentionOp()
FlashAttentionOpImpl<T, 1>::~FlashAttentionOpImpl()
{
}
template<typename T>
int FlashAttentionOp<T>::get_workspace_size() const
int FlashAttentionOpImpl<T, 1>::get_workspace_size() const
{
return pimpl->get_workspace_size();
}
template<typename T>
void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const
void FlashAttentionOpImpl<T, 1>::operator()(Params& params, cudaStream_t st) const
{
pimpl->operator()(params, st);
}
template class FlashAttentionOp<float>;
template class FlashAttentionOp<half>;
template class FlashAttentionOpImpl<float, 1>;
template class FlashAttentionOpImpl<half, 1>;
} // namespace turbomind
@@ -329,11 +329,11 @@ static inline __device__ uint32_t float4_to_char4(float x, float y, float z, flo
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
#else
char4 tmp;
tmp.x = x;
tmp.y = y;
tmp.z = z;
tmp.w = w;
dst = reinterpret_cast<const uint32_t&>(tmp);
tmp.x = x;
tmp.y = y;
tmp.z = z;
tmp.w = w;
dst = reinterpret_cast<const uint32_t&>(tmp);
#endif
return dst;
}
@@ -724,4 +724,61 @@ void invokeGatherOutput(int* output_ids,
output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
}
#define VERSION_SWITCH(VERSION, CONST_NAME, ...) \
[&] { \
if (VERSION == 2) { \
constexpr static int CONST_NAME = 2; \
return __VA_ARGS__(); \
} \
else { \
constexpr static int CONST_NAME = 1; \
return __VA_ARGS__(); \
} \
}()
template<typename T>
FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
batch_size_(batch_size), head_num_(head_num), key_len_(key_len), seq_len_(seq_len), size_per_head_(size_per_head)
{
#ifdef _MSC_VER
op_version_ = 1;
#else
op_version_ = std::is_same<half, typename std::decay<T>::type>::value ? 2 : 1;
if (op_version_ == 2 && getSMVersion() < 80) {
op_version_ = 1;
}
#endif
}
template<typename T>
int FlashAttentionOp<T>::get_workspace_size() const
{
#ifdef _MSC_VER
FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op.get_workspace_size();
#else
return VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op.get_workspace_size();
});
#endif
}
template<typename T>
void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const
{
#ifdef _MSC_VER
FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op(params, st);
#else
return VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op(params, st);
});
#endif
}
template class FlashAttentionOp<float>;
template class FlashAttentionOp<half>;
} // namespace turbomind
@@ -79,36 +79,41 @@ void invokeGatherOutput(int* output_ids,
void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);
template<typename T>
class FlashAttentionOp {
struct BaseAttentionLayout {
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
int batch_seqs_offset = 0;
T** batch_seqs = nullptr;
};
template<typename T>
struct BaseAttentionParams {
T* attn_out;
T* query;
T* key;
T* val;
T* mask;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
size_t group_size = 1;
BaseAttentionLayout<T> layout_q;
BaseAttentionLayout<T> layout_k;
BaseAttentionLayout<T> layout_v;
BaseAttentionLayout<T> layout_o;
};
template<typename T, int version>
class FlashAttentionOpImpl {
public:
struct AttentionLayout {
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
int batch_seqs_offset = 0;
T** batch_seqs = nullptr;
};
struct Params {
T* attn_out;
T* query;
T* key;
T* val;
T* mask;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
size_t group_size = 1;
AttentionLayout layout_q;
AttentionLayout layout_k;
AttentionLayout layout_v;
AttentionLayout layout_o;
};
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;
public:
FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOp();
FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOpImpl();
int get_workspace_size() const;
@@ -119,6 +124,28 @@ private:
std::unique_ptr<impl> pimpl;
};
template<typename T>
class FlashAttentionOp {
public:
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;
public:
FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
int get_workspace_size() const;
void operator()(Params& params, cudaStream_t st) const;
private:
int batch_size_;
int head_num_;
int key_len_;
int seq_len_;
int size_per_head_;
int op_version_;
};
template<typename T>
inline void dump(const T* x, int size, cudaStream_t st, const char* msg, bool full = false)
{
......
@@ -10,6 +10,8 @@
#include <thrust/host_vector.h>
#include <thrust/transform.h>
#undef TORCH_CUDA
#include "src/turbomind/kernels/bert_preprocess_kernels.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
@@ -202,15 +204,6 @@ void naive_mha(scalar_t* out_ptr,
pad_out(out_ptr, cu_seqlens, batch_size, head_num * seq_len * size_per_head, head_num * size_per_head, stream);
}
template<typename scalar_t>
struct UpdateMask {
UpdateMask() {}
__host__ __device__ scalar_t operator()(const scalar_t& x) const
{
return x > scalar_t(0.0f) ? scalar_t(1.0f) : scalar_t(0.0f);
}
};
static const char* usage = "Usage: %s <batch-size> <num-heads> <key-len> <query-len> <size-per-head>\n"
"Example: $test_context_attention_layer 2, 8, 1024, 512, 128\n";
@@ -290,13 +283,14 @@ int main(int argc, const char* argv[])
cudaRandomUniform<scalar_t>(key_ptr, batch_size * num_heads * key_len * size_per_head);
cudaRandomUniform<scalar_t>(val_ptr, batch_size * num_heads * key_len * size_per_head);
cudaRandomUniform<scalar_t>(mask_ptr, batch_size * seq_len * key_len);
thrust::transform(
thrust::device, mask_ptr, mask_ptr + batch_size * seq_len * key_len, mask_ptr, UpdateMask<scalar_t>());
// create random length for batch
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
// for(int batch_id=0;batch_id<batch_size;++batch_id){
// input_lengths_host[batch_id] = seq_len;
// }
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
size_t h_token_num = 0;
size_t* h_pinned_token_num;
@@ -312,6 +306,12 @@ int main(int argc, const char* argv[])
stream);
cudaFreeHost((void*)h_pinned_token_num);
int* k_lens = (int*)allocator.malloc(batch_size * sizeof(int));
deviceFill(k_lens, batch_size, key_len, stream);
invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, k_lens, seq_len, key_len, batch_size, stream);
// deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream);
// compute gt
naive_mha<scalar_t>(expect_out_ptr,
query_ptr,
@@ -334,7 +334,12 @@ int main(int argc, const char* argv[])
&cublas_wrapper);
// compute actual
using AttentionOp = FlashAttentionOp<scalar_t>;
#ifdef _MSC_VER
static constexpr int FMHA_VERSION = 1;
#else
static constexpr int FMHA_VERSION = 2;
#endif
using AttentionOp = FlashAttentionOpImpl<scalar_t, FMHA_VERSION>;
using Layout = typename AttentionOp::AttentionLayout;
Layout layout_q{num_heads * seq_len * size_per_head, size_per_head, seq_len * size_per_head};
Layout layout_k{num_heads * key_len * size_per_head, size_per_head, key_len * size_per_head};
@@ -359,11 +364,13 @@ int main(int argc, const char* argv[])
flash_attention(attn_params, stream);
sync_check_cuda_error();
// int num_rows = 8;
// printf("expect:\n");
// printMatrix(expect_out_ptr, num_rows, size_per_head, size_per_head, true);
// printf("actual:\n");
// printMatrix(actual_out_ptr, num_rows, size_per_head, size_per_head, true);
int num_rows = 8;
// printf("query:\n");
// printMatrix(query_ptr, num_rows, 8, size_per_head, true);
printf("expect:\n");
printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
printf("actual:\n");
printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
checkResult(
"all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true);
......