Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd7c9ad8
Unverified
Commit
dd7c9ad8
authored
Jan 16, 2025
by
Isotr0py
Committed by
GitHub
Jan 16, 2025
Browse files
[Bugfix] Remove hardcoded `head_size=256` for Deepseek v2 and v3 (#12067)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
9aa1519f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
40 deletions
+23
-40
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+3
-3
vllm/config.py
vllm/config.py
+6
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+7
-17
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+7
-17
No files found.
tests/kernels/test_attention.py
View file @
dd7c9ad8
...
...
@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS
=
[
3
]
# Arbitrary values for testing
NUM_HEADS
=
[(
40
,
40
),
(
64
,
8
)]
# Arbitrary values for testing
#
FlashAttention forward only
support
s
head
dimension at most 128
#
https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
120
,
256
]
#
This should be sync with get_
support
ed_
head
_sizes() in
#
vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES
=
[
32
,
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
USE_ALIBI
=
[
False
,
True
]
...
...
vllm/config.py
View file @
dd7c9ad8
...
...
@@ -733,9 +733,12 @@ class ModelConfig:
if
hasattr
(
self
.
hf_text_config
,
"model_type"
)
and
(
self
.
hf_text_config
.
model_type
in
(
'deepseek_v2'
,
'deepseek_v3'
)):
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return
256
qk_rope_head_dim
=
getattr
(
self
.
hf_text_config
,
"qk_rope_head_dim"
,
0
)
qk_nope_head_dim
=
getattr
(
self
.
hf_text_config
,
"qk_nope_head_dim"
,
0
)
if
qk_rope_head_dim
and
qk_nope_head_dim
:
return
qk_rope_head_dim
+
qk_nope_head_dim
if
self
.
is_attention_free
:
return
0
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
dd7c9ad8
...
...
@@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self
.
attn
=
Attention
(
self
.
num_local_heads
,
256
,
self
.
qk_head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_local_heads
,
cache_config
=
cache_config
,
...
...
@@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
k
=
torch
.
empty_like
(
q
)
k
[...,
:
self
.
qk_nope_head_dim
]
=
k_nope
k
[...,
self
.
qk_nope_head_dim
:]
=
k_pe
q
=
torch
.
nn
.
functional
.
pad
(
q
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
k
=
torch
.
nn
.
functional
.
pad
(
k
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
256
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
# padding value to qk_head_dim for alignment
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
self
.
qk_head_dim
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
self
.
qk_head_dim
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_local_heads
,
256
)[...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)[...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_local_heads
*
self
.
v_head_dim
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
dd7c9ad8
...
...
@@ -269,14 +269,8 @@ class DeepseekV3Attention(nn.Module):
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self
.
attn
=
Attention
(
self
.
num_local_heads
,
256
,
self
.
qk_head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_local_heads
,
cache_config
=
cache_config
,
...
...
@@ -326,18 +320,14 @@ class DeepseekV3Attention(nn.Module):
k
=
torch
.
empty_like
(
q
)
k
[...,
:
self
.
qk_nope_head_dim
]
=
k_nope
k
[...,
self
.
qk_nope_head_dim
:]
=
k_pe
q
=
torch
.
nn
.
functional
.
pad
(
q
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
k
=
torch
.
nn
.
functional
.
pad
(
k
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
256
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
# padding value to qk_head_dim for alignment
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
self
.
qk_head_dim
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
self
.
qk_head_dim
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_local_heads
,
256
)[...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)[...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_local_heads
*
self
.
v_head_dim
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment