OpenDAS / Lmdeploy · Commits · 5ed6bb59

Unverified commit 5ed6bb59, authored Jul 25, 2023 by q.yao, committed by GitHub on Jul 25, 2023

support fmha gqa (#160)

Co-authored-by: grimoire <yaoqian@pjlab.org.cn>

Parent: 5203c850
Showing 4 changed files with 12 additions and 6 deletions (+12 −6)
lmdeploy/serve/turbomind/deploy.py                                                      +1 −1
src/turbomind/models/llama/LlamaContextAttentionLayer.cc                                +3 −3
src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu   +7 −2
src/turbomind/models/llama/llama_kernels.h                                              +1 −0
lmdeploy/serve/turbomind/deploy.py
@@ -196,7 +196,7 @@ def export(model_name: str,
                   step_length=1,
                   cache_max_entry_count=48,
                   cache_chunk_size=1,
-                  use_context_fmha=int(kv_head_num == head_num),
+                  use_context_fmha=1,
                   quant_policy=0,
                   tensor_para_size=tp))
src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -208,7 +208,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     sync_check_cuda_error();
 
     if (use_fmha_) {
-        FT_CHECK(local_head_num_ == local_kv_head_num_);
         fusedMultiHeadAttention(k_cache_ptrs,
                                 v_cache_ptrs,
                                 layer_offset,
@@ -285,8 +284,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
         .stride_head = int(size_per_head_),
         .use_seqlens = true,
     };
 
-    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
+    size_t group_size = size_t(local_head_num_ / local_kv_head_num_);
+    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
     typename AttentionOp::Params attn_params{.attn_out = qkv_buf_3_,
                                              .query = q_buf_2_,
                                              .key = k_cache_buf_,
@@ -295,6 +294,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
                                              .out_accum = qk_buf_float_,
                                              .cu_seqlens_q = cu_seqlens,
                                              .cu_seqlens_k = nullptr,
+                                             .group_size = group_size,
                                              .layout_q = layout_q,
                                              .layout_k = layout_k,
                                              .layout_v = layout_v,
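The hunks above drop the FT_CHECK that previously restricted FMHA to models with equal query and KV head counts, and instead derive a group size that is forwarded to the kernel through attn_params. A standalone sketch of that derivation, with the divisibility assumption made explicit (the struct, function name, and assert are illustrative, not code from this commit):

```cpp
#include <cassert>
#include <cstddef>

// Stand-ins for the layer members after tensor-parallel splitting.
struct HeadConfig {
    std::size_t local_head_num_;
    std::size_t local_kv_head_num_;
};

// Mirrors the added line in fusedMultiHeadAttention():
//   size_t group_size = size_t(local_head_num_ / local_kv_head_num_);
std::size_t compute_group_size(const HeadConfig& c)
{
    // The commit relies on the query-head count being a multiple of the
    // KV-head count; this assert is illustrative and not part of the diff.
    assert(c.local_kv_head_num_ > 0 && c.local_head_num_ % c.local_kv_head_num_ == 0);
    return c.local_head_num_ / c.local_kv_head_num_;
}

int main()
{
    assert(compute_group_size({32, 32}) == 1);  // MHA: every query head has its own KV head
    assert(compute_group_size({32, 4}) == 8);   // GQA: 8 query heads per KV head
    return 0;
}
```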
src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu
@@ -79,6 +79,8 @@ struct LlamaAttentionKernel:
         int32_t o_strideM_custom = 0;
 
+        int32_t group_size = 1;
+
         float scale;
 
         CUTLASS_HOST_DEVICE int32_t o_strideM() const
@@ -199,8 +201,8 @@ struct LlamaAttentionKernel:
         // Advance to the current batch / head / query_start
         query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH;
-        key_ptr += k_start * k_strideM + head_id * k_strideH;
-        value_ptr += k_start * v_strideM + head_id * v_strideH;
+        key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH;
+        value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH;
         output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH;
 
         if (output_accum_ptr != nullptr) {
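This hunk is the core of the GQA support: the query and output pointers still advance by head_id, while the key and value pointers now advance by head_id / group_size, so a whole group of query heads reads the same K/V slice. A toy standalone example of the offset arithmetic (the stride and group size are invented for illustration):

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Toy version of the pointer advance above: query/output offsets still scale
// with head_id, while key/value offsets scale with head_id / group_size.
int main()
{
    const std::int64_t k_strideH  = 128;  // elements between consecutive KV heads (example)
    const int          group_size = 8;    // e.g. 32 query heads / 4 KV heads

    for (int head_id : {0, 7, 8, 31}) {
        std::int64_t old_offset = std::int64_t(head_id) * k_strideH;               // pre-GQA behaviour
        std::int64_t new_offset = std::int64_t(head_id / group_size) * k_strideH;  // post-GQA behaviour
        std::printf("head %2d: old key offset %6lld, new key offset %6lld\n",
                    head_id, (long long)old_offset, (long long)new_offset);
    }
    return 0;
}
```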
@@ -668,6 +670,7 @@ void invokeFlashAttention_impl(int batch_size,
     auto layout_k = attention_params.layout_k;
     auto layout_v = attention_params.layout_v;
     auto layout_o = attention_params.layout_o;
+    auto group_size = attention_params.group_size;
 
     using scalar_t =
         typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
@@ -731,6 +734,8 @@ void invokeFlashAttention_impl(int batch_size,
         params.num_batches = batch_size;
         params.num_heads = head_num;
+        params.group_size = int32_t(group_size);
     }
 
     Attention::check_supported(params);
src/turbomind/models/llama/llama_kernels.h
@@ -99,6 +99,7 @@ public:
     float* out_accum = nullptr;
     int* cu_seqlens_q = nullptr;
     int* cu_seqlens_k = nullptr;
+    size_t group_size = 1;
     AttentionLayout layout_q;
     AttentionLayout layout_k;
     AttentionLayout layout_v;
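Because the new group_size member defaults to 1, call sites that never set it keep the old multi-head behaviour. A simplified stand-in for the Params struct above, just to illustrate the default (field names follow the diff, but the struct itself is illustrative, not the real definition):

```cpp
#include <cassert>
#include <cstddef>

// Simplified stand-in for the attention Params struct above: group_size
// defaults to 1, so existing callers that never set it keep plain
// multi-head behaviour (head_id / 1 == head_id).
struct ParamsSketch {
    int*        cu_seqlens_q = nullptr;
    int*        cu_seqlens_k = nullptr;
    std::size_t group_size   = 1;
};

int main()
{
    ParamsSketch defaults{};            // untouched call sites: group_size stays 1
    ParamsSketch gqa{.group_size = 8};  // the GQA path opts in explicitly
    assert(defaults.group_size == 1);
    assert(gqa.group_size == 8);
    return 0;
}
```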