Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef8f16f4
Commit
ef8f16f4
authored
May 07, 2025
by
zhuwenwen
Browse files
update v1 mla
parent
660af62e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+14
-4
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
ef8f16f4
...
@@ -190,6 +190,7 @@ from dataclasses import dataclass
...
@@ -190,6 +190,7 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
import
torch
import
torch
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
...
@@ -643,13 +644,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -643,13 +644,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
functools
.
partial
(
flash_attn_varlen_func
,
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
fa_version
=
self
.
vllm_flash_attn_version
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
and
current_platform
.
get_device_capability
()[
0
]
==
9
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
)
def
_flash_attn_varlen_diff_headdims
(
self
,
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
q
,
...
@@ -660,8 +664,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -660,8 +664,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
**
kwargs
):
**
kwargs
):
maybe_padded_v
=
v
maybe_padded_v
=
v
if
self
.
_pad_v
:
if
self
.
_pad_v
:
# maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
],
value
=
0
)
maybe_padded_v
=
maybe_padded_v
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
attn_out
=
self
.
flash_attn_varlen_func
(
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
...
@@ -737,6 +744,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -737,6 +744,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# we currently do not have quantized bmm's which are needed for
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
if
self
.
use_llama_nn
and
isinstance
(
self
.
kv_b_proj
.
quant_method
,
UnquantizedLinearMethod
):
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
)
else
:
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
kv_lora_rank
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment