Unverified commit 5ed6bb59 authored by q.yao, committed by GitHub

support fmha gqa (#160)


Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
parent 5203c850
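
For reference, grouped-query attention (GQA) lets several query heads share one key/value head; the changes below thread a group size of head_num / kv_head_num through the fused attention path. A minimal standalone sketch of that head mapping, assuming head_num is divisible by kv_head_num (all names and values here are illustrative, not the library's API):

#include <cstdio>

// Illustrative only: each group of `group_size` query heads reads the same KV head.
int main() {
    const int head_num    = 32;                      // example query-head count
    const int kv_head_num = 8;                       // example KV-head count (GQA)
    const int group_size  = head_num / kv_head_num;  // 4 query heads per KV head

    for (int head_id = 0; head_id < head_num; ++head_id) {
        const int kv_head_id = head_id / group_size; // same mapping the kernel change below uses
        std::printf("query head %2d -> kv head %d\n", head_id, kv_head_id);
    }
    return 0;
}
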
@@ -196,7 +196,7 @@ def export(model_name: str,
                  step_length=1,
                  cache_max_entry_count=48,
                  cache_chunk_size=1,
-                 use_context_fmha=int(kv_head_num == head_num),
+                 use_context_fmha=1,
                  quant_policy=0,
                  tensor_para_size=tp))
...
@@ -208,7 +208,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     sync_check_cuda_error();
     if (use_fmha_) {
-        FT_CHECK(local_head_num_ == local_kv_head_num_);
         fusedMultiHeadAttention(k_cache_ptrs,
                                 v_cache_ptrs,
                                 layer_offset,
@@ -285,8 +284,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
         .stride_head = int(size_per_head_),
         .use_seqlens = true,
     };
-    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
+    size_t      group_size = size_t(local_head_num_ / local_kv_head_num_);
+    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
     typename AttentionOp::Params attn_params{.attn_out = qkv_buf_3_,
                                              .query = q_buf_2_,
                                              .key = k_cache_buf_,
@@ -295,6 +294,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
                                              .out_accum = qk_buf_float_,
                                              .cu_seqlens_q = cu_seqlens,
                                              .cu_seqlens_k = nullptr,
+                                             .group_size = group_size,
                                              .layout_q = layout_q,
                                              .layout_k = layout_k,
                                              .layout_v = layout_v,
...
@@ -79,6 +79,8 @@ struct LlamaAttentionKernel:
         int32_t o_strideM_custom = 0;
+        int32_t group_size = 1;
         float scale;
         CUTLASS_HOST_DEVICE int32_t o_strideM() const
@@ -199,8 +201,8 @@ struct LlamaAttentionKernel:
             // Advance to the current batch / head / query_start
             query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH;
-            key_ptr += k_start * k_strideM + head_id * k_strideH;
-            value_ptr += k_start * v_strideM + head_id * v_strideH;
+            key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
+            value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH;
             output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH;
             if (output_accum_ptr != nullptr) {
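
The pointer arithmetic above is what implements the sharing: query and output pointers advance by the full head_id, while key/value pointers advance by head_id / group_size, so every query head in a group lands on the same K/V slice. A hedged standalone sketch with assumed strides (not the kernel's actual layout):

#include <cstdint>
#include <cstdio>

// Illustrative only: why K/V use head_id / group_size while Q/O keep head_id.
int main() {
    const int64_t size_per_head = 128;
    const int64_t head_num      = 32;
    const int64_t kv_head_num   = 8;
    const int64_t group_size    = head_num / kv_head_num;

    // Assumed per-head strides for a [token, head, dim] layout (not the library's layout).
    const int64_t q_strideH = size_per_head;
    const int64_t k_strideH = size_per_head;

    const int64_t head_id = 13;                                  // some query head
    const int64_t q_off   = head_id * q_strideH;                 // one slice per query head
    const int64_t k_off   = (head_id / group_size) * k_strideH;  // shared slice per KV group

    std::printf("q offset = %lld, kv offset = %lld\n", (long long)q_off, (long long)k_off);
    return 0;
}
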
@@ -668,6 +670,7 @@ void invokeFlashAttention_impl(int batch_size,
     auto layout_k = attention_params.layout_k;
     auto layout_v = attention_params.layout_v;
     auto layout_o = attention_params.layout_o;
+    auto group_size = attention_params.group_size;
     using scalar_t =
         typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
@@ -731,6 +734,8 @@ void invokeFlashAttention_impl(int batch_size,
         params.num_batches = batch_size;
         params.num_heads = head_num;
+        params.group_size = int32_t(group_size);
     }
     Attention::check_supported(params);
...
@@ -99,6 +99,7 @@ public:
         float* out_accum = nullptr;
         int* cu_seqlens_q = nullptr;
         int* cu_seqlens_k = nullptr;
+        size_t group_size = 1;
         AttentionLayout layout_q;
         AttentionLayout layout_k;
         AttentionLayout layout_v;
...
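
Because the new group_size field defaults to 1, callers that never set it keep plain multi-head behavior (head_id / 1 == head_id). A tiny sketch of that invariant, using a stand-in struct rather than the real Params type:

#include <cassert>
#include <cstddef>

// Stand-in for the Params struct above, only to show the backward-compatible default.
struct ParamsLike {
    size_t group_size = 1;  // 1 => every query head owns its KV head (plain MHA)
};

int main() {
    ParamsLike p;               // a caller that never sets group_size
    assert(p.group_size == 1);  // head_id / 1 == head_id, so K/V indexing is unchanged
    return 0;
}
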