Unverified Commit 5ed6bb59 authored by q.yao, committed by GitHub

support fmha gqa (#160)


Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
parent 5203c850
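This change teaches the fused context attention (FMHA) path about grouped-query attention (GQA), where several query heads share one key/value head. Below is a minimal, illustrative sketch of the head mapping the diff relies on, assuming head_num is an exact multiple of kv_head_num; it is not code from the repository.

    // Illustrative only: the query-head -> KV-head mapping used by GQA.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int32_t head_num    = 8;  // toy numbers: 8 query heads
        const int32_t kv_head_num = 2;  //              2 shared KV heads
        assert(head_num % kv_head_num == 0);

        const int32_t group_size = head_num / kv_head_num;  // query heads per KV head
        for (int32_t head_id = 0; head_id < head_num; ++head_id) {
            // Consecutive query heads read the same KV head, exactly the
            // head_id / group_size indexing added to the kernel below.
            std::printf("query head %d -> kv head %d\n", int(head_id), int(head_id / group_size));
        }
        return 0;
    }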
@@ -196,7 +196,7 @@ def export(model_name: str,
                  step_length=1,
                  cache_max_entry_count=48,
                  cache_chunk_size=1,
-                 use_context_fmha=int(kv_head_num == head_num),
+                 use_context_fmha=1,
                  quant_policy=0,
                  tensor_para_size=tp))
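With GQA handled inside the fused kernel, the exporter no longer has to fall back when kv_head_num differs from head_num, so use_context_fmha becomes an unconditional 1. A hedged sketch of the only shape constraint this relies on (the helper name is invented for illustration):

    // Hypothetical helper, not part of the patch: context FMHA is usable for
    // MHA (group_size == 1), GQA (1 < group_size < head_num) and MQA
    // (group_size == head_num) alike, as long as the heads divide evenly.
    #include <cstdint>
    #include <stdexcept>

    int32_t context_fmha_group_size(int32_t head_num, int32_t kv_head_num)
    {
        if (kv_head_num <= 0 || head_num % kv_head_num != 0) {
            throw std::invalid_argument("head_num must be a positive multiple of kv_head_num");
        }
        return head_num / kv_head_num;  // e.g. 64 query heads / 8 KV heads -> 8
    }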
@@ -208,7 +208,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     sync_check_cuda_error();
     if (use_fmha_) {
-        FT_CHECK(local_head_num_ == local_kv_head_num_);
         fusedMultiHeadAttention(k_cache_ptrs,
                                 v_cache_ptrs,
                                 layer_offset,
@@ -285,8 +284,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
         .stride_head = int(size_per_head_),
         .use_seqlens = true,
     };
+    size_t group_size = size_t(local_head_num_ / local_kv_head_num_);
     AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
     typename AttentionOp::Params attn_params{.attn_out = qkv_buf_3_,
                                              .query = q_buf_2_,
                                              .key = k_cache_buf_,
@@ -295,6 +294,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
                                              .out_accum = qk_buf_float_,
                                              .cu_seqlens_q = cu_seqlens,
                                              .cu_seqlens_k = nullptr,
+                                             .group_size = group_size,
                                              .layout_q = layout_q,
                                              .layout_k = layout_k,
                                              .layout_v = layout_v,
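On the host side, the layer now derives group_size from its local head counts and forwards it through the attention parameters. The snippet below mirrors that wiring with a stripped-down stand-in for the params struct and toy tensor-parallel numbers; the names other than group_size are placeholders, not the real types.

    // Sketch only: how group_size is derived and plumbed into the kernel params.
    #include <cstddef>
    #include <cstdio>

    struct ToyAttentionParams {      // stand-in for the real params struct
        std::size_t group_size = 1;  // 1 keeps the pre-GQA (plain MHA) behaviour
    };

    int main()
    {
        const std::size_t local_head_num    = 40 / 2;  // e.g. 40 query heads, TP=2
        const std::size_t local_kv_head_num = 8 / 2;   // e.g. 8 KV heads,     TP=2

        ToyAttentionParams attn_params;
        attn_params.group_size = local_head_num / local_kv_head_num;  // = 5
        std::printf("group_size = %zu\n", attn_params.group_size);
        return 0;
    }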
@@ -79,6 +79,8 @@ struct LlamaAttentionKernel:
         int32_t o_strideM_custom = 0;
+        int32_t group_size = 1;
         float scale;
         CUTLASS_HOST_DEVICE int32_t o_strideM() const
@@ -199,8 +201,8 @@ struct LlamaAttentionKernel:
         // Advance to the current batch / head / query_start
         query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH;
-        key_ptr += k_start * k_strideM + head_id * k_strideH;
-        value_ptr += k_start * v_strideM + head_id * v_strideH;
+        key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
+        value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH;
         output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH;
         if (output_accum_ptr != nullptr) {
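Inside the kernel, only the key/value pointer advance changes: the query and output pointers still step by head_id, while K/V step by head_id / group_size, so a whole group of query heads lands on the same K/V slice. A self-contained sketch of that offset arithmetic with made-up strides:

    // Illustrative offset math only; strides and head counts are made up.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int32_t group_size = 4;    // 4 query heads per KV head
        const int64_t k_strideH  = 128;  // elements between consecutive KV heads
        const int64_t q_strideH  = 128;  // elements between consecutive query heads

        for (int32_t head_id = 0; head_id < 8; ++head_id) {
            const int64_t q_off = head_id * q_strideH;
            const int64_t k_off = int64_t(head_id / group_size) * k_strideH;  // as in the patch
            std::printf("head %d: query offset %lld, shared KV offset %lld\n",
                        int(head_id), (long long)q_off, (long long)k_off);
        }
        return 0;
    }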
@@ -668,6 +670,7 @@ void invokeFlashAttention_impl(int batch_size,
     auto layout_k = attention_params.layout_k;
     auto layout_v = attention_params.layout_v;
     auto layout_o = attention_params.layout_o;
+    auto group_size = attention_params.group_size;
     using scalar_t =
         typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
@@ -731,6 +734,8 @@ void invokeFlashAttention_impl(int batch_size,
         params.num_batches = batch_size;
         params.num_heads = head_num;
+        params.group_size = int32_t(group_size);
     }
     Attention::check_supported(params);
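The launcher copies group_size from the host-side parameter block (a size_t) into the kernel's int32_t field. A hedged sketch of the kind of narrowing-and-sanity check a caller could apply before the launch; check_group_size is a hypothetical helper, not something this file defines:

    // Hypothetical pre-launch validation, shown for illustration only.
    #include <cstddef>
    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    int32_t check_group_size(std::size_t group_size, int32_t head_num)
    {
        if (group_size == 0 || group_size > std::size_t(std::numeric_limits<int32_t>::max())) {
            throw std::invalid_argument("group_size out of range");
        }
        if (head_num % int32_t(group_size) != 0) {
            throw std::invalid_argument("head_num must be a multiple of group_size");
        }
        return int32_t(group_size);  // safe to store into the kernel params
    }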
@@ -99,6 +99,7 @@ public:
     float* out_accum = nullptr;
     int* cu_seqlens_q = nullptr;
     int* cu_seqlens_k = nullptr;
+    size_t group_size = 1;
     AttentionLayout layout_q;
     AttentionLayout layout_k;
     AttentionLayout layout_v;
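Because the new field defaults to 1, existing multi-head-attention callers of this parameter struct keep their behaviour without any change; only GQA callers set it explicitly. A short usage sketch against a simplified stand-in struct (the real struct has many more fields):

    // Stand-in struct for illustration; only the new field is shown.
    #include <cstddef>

    struct ParamsSketch {
        std::size_t group_size = 1;  // 1 == classic MHA, the pre-patch behaviour
    };

    int main()
    {
        ParamsSketch mha_params;         // untouched caller: group_size stays 1
        ParamsSketch gqa_params;
        gqa_params.group_size = 64 / 8;  // e.g. 64 query heads sharing 8 KV heads
        return (mha_params.group_size == 1 && gqa_params.group_size == 8) ? 0 : 1;
    }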