OpenDAS / Lmdeploy · Commits

Commit 5ed6bb59 (unverified)
Authored Jul 25, 2023 by q.yao; committed by GitHub on Jul 25, 2023

    support fmha gqa (#160)

    Co-authored-by: grimoire <yaoqian@pjlab.org.cn>

Parent: 5203c850
4 changed files with 12 additions and 6 deletions (+12 -6)
Changed files:

    lmdeploy/serve/turbomind/deploy.py                                                      +1 -1
    src/turbomind/models/llama/LlamaContextAttentionLayer.cc                                +3 -3
    src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu   +7 -2
    src/turbomind/models/llama/llama_kernels.h                                              +1 -0
lmdeploy/serve/turbomind/deploy.py

@@ -196,7 +196,7 @@ def export(model_name: str,
                       step_length=1,
                       cache_max_entry_count=48,
                       cache_chunk_size=1,
-                      use_context_fmha=int(kv_head_num == head_num),
+                      use_context_fmha=1,
                       quant_policy=0,
                       tensor_para_size=tp))
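Before this commit, context FMHA was only enabled when the KV head count matched the query head count, so GQA models fell back to the unfused attention path; with group-size support added to the kernel below, the flag can stay on unconditionally. A minimal standalone sketch (not part of the repo; the head counts are made up for illustration) of the relationship the change relies on:

// Illustrative only: context FMHA can stay enabled for GQA models because the
// kernel derives a per-head group size instead of requiring
// head_num == kv_head_num. The implicit assumption is that head_num is an
// exact multiple of kv_head_num.
#include <cassert>
#include <cstdio>

int main()
{
    const int head_num    = 64;  // e.g. LLaMA-2-70B query heads (illustrative)
    const int kv_head_num = 8;   // e.g. LLaMA-2-70B KV heads (illustrative)

    // GQA is well-formed only when query heads divide evenly into KV groups.
    assert(head_num % kv_head_num == 0);
    const int group_size = head_num / kv_head_num;

    // Before this commit: use_context_fmha = int(kv_head_num == head_num) -> 0 for GQA.
    // After this commit:  use_context_fmha = 1, since the kernel handles group_size.
    std::printf("group_size = %d\n", group_size);  // prints 8
    return 0;
}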
src/turbomind/models/llama/LlamaContextAttentionLayer.cc

@@ -208,7 +208,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     sync_check_cuda_error();

     if (use_fmha_) {
-        FT_CHECK(local_head_num_ == local_kv_head_num_);
         fusedMultiHeadAttention(k_cache_ptrs, v_cache_ptrs, layer_offset,

@@ -285,8 +284,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
         .stride_head = int(size_per_head_),
         .use_seqlens = true,
     };
-    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
+    size_t      group_size = size_t(local_head_num_ / local_kv_head_num_);
+    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
     typename AttentionOp::Params attn_params{.attn_out     = qkv_buf_3_,
                                              .query        = q_buf_2_,
                                              .key          = k_cache_buf_,

@@ -295,6 +294,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
                                              .out_accum    = qk_buf_float_,
                                              .cu_seqlens_q = cu_seqlens,
                                              .cu_seqlens_k = nullptr,
+                                             .group_size   = group_size,
                                              .layout_q     = layout_q,
                                              .layout_k     = layout_k,
                                              .layout_v     = layout_v,
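A minimal sketch (illustrative only, with made-up head counts) of the mapping that group_size = local_head_num_ / local_kv_head_num_ encodes: consecutive query heads share one KV head, so head_id / group_size picks the KV head whose cache a given query head attends over.

#include <cstdio>

int main()
{
    const int head_num    = 32;                      // query heads (assumed)
    const int kv_head_num = 8;                       // KV heads (assumed)
    const int group_size  = head_num / kv_head_num;  // 4 query heads per KV head

    for (int head_id = 0; head_id < head_num; ++head_id) {
        const int kv_head_id = head_id / group_size;
        if (head_id % group_size == 0) {
            std::printf("query heads %d..%d -> kv head %d\n",
                        head_id, head_id + group_size - 1, kv_head_id);
        }
    }
    return 0;
}

With group_size == 1 (head_num == kv_head_num) this reduces to ordinary multi-head attention, which is why the FT_CHECK guard in forward() could be dropped.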
src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu

@@ -79,6 +79,8 @@ struct LlamaAttentionKernel:
        int32_t o_strideM_custom = 0;
+       int32_t group_size       = 1;

        float scale;

        CUTLASS_HOST_DEVICE int32_t o_strideM() const

@@ -199,8 +201,8 @@ struct LlamaAttentionKernel:
        // Advance to the current batch / head / query_start
        query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH;
-       key_ptr += k_start * k_strideM + head_id * k_strideH;
-       value_ptr += k_start * v_strideM + head_id * v_strideH;
+       key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
+       value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH;
        output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH;

        if (output_accum_ptr != nullptr) {

@@ -668,6 +670,8 @@ void invokeFlashAttention_impl(int batch_size,
    auto layout_k = attention_params.layout_k;
    auto layout_v = attention_params.layout_v;
    auto layout_o = attention_params.layout_o;
+   auto group_size = attention_params.group_size;

    using scalar_t = typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;

@@ -731,6 +734,8 @@ void invokeFlashAttention_impl(int batch_size,
        params.num_batches = batch_size;
        params.num_heads   = head_num;
+       params.group_size  = int32_t(group_size);
    }

    Attention::check_supported(params);
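A standalone sketch of the pointer arithmetic changed above, using hypothetical strides (k_strideH, k_strideM and the concrete sizes are made up; the real kernel derives them from its layout parameters). The only difference from the MHA version is that the head index is divided by group_size before scaling by the per-head stride, so every query head in a group resolves to the same K/V base offset.

#include <cstdint>
#include <cstdio>

int main()
{
    const int     group_size    = 4;    // assumed: 4 query heads per KV head
    const int     max_k_len     = 128;  // assumed sequence length
    const int     size_per_head = 64;   // assumed head dimension
    const int64_t k_strideH     = int64_t(max_k_len) * size_per_head;  // stride between heads
    const int64_t k_strideM     = size_per_head;                       // stride between tokens
    const int     k_start       = 0;

    for (int head_id = 0; head_id < 8; ++head_id) {
        // MHA-style offset: one K slice per query head.
        const int64_t off_mha = k_start * k_strideM + int64_t(head_id) * k_strideH;
        // GQA-style offset: query heads in the same group share a K slice.
        const int64_t off_gqa = k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
        std::printf("head %d: mha offset %lld, gqa offset %lld\n",
                    head_id, (long long)off_mha, (long long)off_gqa);
    }
    return 0;
}

Note that query_ptr and output_ptr still advance by head_id directly: every query head has its own queries and outputs, and only the K and V caches are shared within a group.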
src/turbomind/models/llama/llama_kernels.h

@@ -99,6 +99,7 @@ public:
        float* out_accum    = nullptr;
        int*   cu_seqlens_q = nullptr;
        int*   cu_seqlens_k = nullptr;
+       size_t group_size   = 1;
        AttentionLayout layout_q;
        AttentionLayout layout_k;
        AttentionLayout layout_v;
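A simplified stand-in (not the library's actual struct, only the fields visible in this hunk) showing why the new field defaults to 1: callers that never set group_size keep plain MHA behaviour, and only the GQA path in LlamaContextAttentionLayer overrides it.

#include <cstddef>
#include <cstdio>

struct ParamsSketch {      // hypothetical, mirrors only the fields relevant here
    float* out_accum    = nullptr;
    int*   cu_seqlens_q = nullptr;
    int*   cu_seqlens_k = nullptr;
    size_t group_size   = 1;  // 1 == every query head owns its KV head (MHA)
};

int main()
{
    ParamsSketch mha{};       // group_size stays 1, standard multi-head attention
    ParamsSketch gqa{};
    gqa.group_size = 4;       // e.g. 32 query heads sharing 8 KV heads
    std::printf("mha=%zu gqa=%zu\n", mha.group_size, gqa.group_size);
    return 0;
}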