Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf28e5a4
Commit
cf28e5a4
authored
Apr 15, 2025
by
zhuwenwen
Browse files
vdim pad 32
parent
9c3190d0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
18 deletions
+8
-18
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+8
-18
No files found.
vllm/attention/backends/mla/utils.py
View file @
cf28e5a4
...
@@ -533,29 +533,16 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -533,29 +533,16 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# For MLA the v head dim is smaller than qk head dim so we pad out
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
(
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
)],
value
=
0
)
value
=
0
)
v_tmp
=
v_padded
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
# if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
# attn_output = flash_attn_varlen_func(
# q=q,
# k=k,
# v=v_padded,
# cu_seqlens_q=seq_start_loc,
# cu_seqlens_k=seq_start_loc,
# max_seqlen_q=max_prefill_seq_len,
# max_seqlen_k=max_prefill_seq_len,
# softmax_scale=self.scale,
# causal=True,
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
# else:
attn_output
=
flash_attn_varlen_func
(
attn_output
=
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v
,
v
=
v_tmp
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
else
v
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
max_seqlen_q
=
max_prefill_seq_len
,
max_seqlen_q
=
max_prefill_seq_len
,
...
@@ -563,6 +550,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -563,6 +550,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
attn_output
=
attn_output
\
attn_output
=
attn_output
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment