Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0daa00fb
Commit
0daa00fb
authored
Mar 09, 2026
by
yangql
Browse files
适配在bmz上的mla的kvcache_e5m2和e4m3量化的支持
parent
cb1a27d2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
42 additions
and
19 deletions
+42
-19
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+3
-2
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+1
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+1
-0
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+33
-16
No files found.
vllm/model_executor/layers/attention/mla_attention.py
View file @
0daa00fb
...
...
@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
from
vllm.utils.flashinfer
import
has_nvidia_artifactory
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.v1.attention.backend
import
(
...
...
@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
)
if
fp8_attention
:
if
fp8_attention
and
get_gcn_arch_name
()
==
"gfx938"
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
if
has_prefill
:
...
...
@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
if
fp8_attention
:
if
fp8_attention
and
get_gcn_arch_name
()
==
"gfx938"
:
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
0daa00fb
...
...
@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
torch
.
bfloat16
),
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
k
.
dtype
,
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
return
sparse_attn_indexer_fake
(
...
...
vllm/platforms/rocm.py
View file @
0daa00fb
...
...
@@ -121,6 +121,10 @@ def on_gfx9() -> bool:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
,
"gfx928"
,
"gfx936"
,
"gfx938"
])
@cache
def get_gcn_arch_name() -> str:
    """Return the leading segment of the "cuda" device's GCN arch name.

    Reads ``gcnArchName`` from the properties of the ``"cuda"`` device and
    drops everything from the first ``:`` onward.

    The function is wrapped in ``cache``; assuming that is
    ``functools.cache``, the device is only queried on the first call and
    the stored result is reused afterwards.
    """
    device_props = torch.cuda.get_device_properties("cuda")
    # partition() yields the text before the first ':' (or the whole string
    # when there is no ':'), which is the same value as split(':')[0].
    base_name, _, _ = device_props.gcnArchName.partition(":")
    return base_name
@
cache
def
on_gfx942
()
->
bool
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
0daa00fb
...
...
@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal
=
True
,
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
else
:
o
,
lse
=
flash_mla_with_kvcache
(
...
...
vllm/v1/attention/ops/flashmla.py
View file @
0daa00fb
...
...
@@ -6,7 +6,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_cuda
():
...
...
@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8(
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
16
,
#
16,
)
else
:
return
torch
.
ops
.
_flashmla_extension_C
.
get_mla_decoding_metadata_dense_fp8
(
...
...
@@ -158,12 +158,14 @@ def flash_mla_with_kvcache_fp8(
causal
:
bool
=
False
,
descale_q
:
torch
.
Tensor
|
None
=
None
,
descale_k
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
_is_flashmla_available
()[
0
]:
_raise_flashmla_unavailable
()
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
get_gcn_arch_name
()
==
"gfx938"
:
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8
(
q
,
k_cache
,
...
...
@@ -178,6 +180,21 @@ def flash_mla_with_kvcache_fp8(
descale_q
,
descale_k
,
)
else
:
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_k
,
kv_cache_dtype
,
)
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_extension_C
.
fwd_kvcache_mla_fp8
(
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment