Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22a46529
Commit
22a46529
authored
Aug 18, 2025
by
zhuwenwen
Browse files
增加marlin对cache13的支持,以及新增flash mla的kvcache fp8的支持
parent
e70b0ea0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
48 additions
and
15 deletions
+48
-15
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+7
-2
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+18
-0
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+14
-9
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-2
No files found.
vllm/attention/backends/flashmla.py
View file @
22a46529
...
...
@@ -207,8 +207,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA with FP8 KV cache not yet supported"
)
if
self
.
kv_cache_dtype
!=
"fp8"
:
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
self
,
...
...
@@ -216,6 +217,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
...
...
@@ -235,6 +238,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/mla/common.py
View file @
22a46529
...
...
@@ -1396,6 +1396,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output
\ No newline at end of file
vllm/attention/ops/flashmla.py
View file @
22a46529
...
...
@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
...
...
@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
:
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
k_scale
,
"fp8_e4m3"
,
)
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
k_cache
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
22a46529
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new
,
maybe_warn_marlin_atomic_add
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
direct_register_custom_op
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
...
...
@@ -104,7 +105,8 @@ def fused_marlin_moe(
topk
=
topk_ids
.
shape
[
1
]
# 8
#暂时固定为16384
CHUNK_SIZE
=
16384
#CHUNK_SIZE = 16384
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
...
...
@@ -122,18 +124,21 @@ def fused_marlin_moe(
if
global_num_experts
==
-
1
:
global_num_experts
=
E
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk
_ids
.
shape
[
1
]
,
N
),
(
M
*
topk
,
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
]
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
intermediate_cache13
=
get_moe_cache
(
topk
,
N
,
K
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk
*
2
*
N
]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk
_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
...
...
vllm/v1/attention/backends/mla/common.py
View file @
22a46529
...
...
@@ -1325,6 +1325,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output_padded
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
22a46529
...
...
@@ -145,8 +145,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA V1 with FP8 KV cache not yet supported"
)
if
self
.
kv_cache_dtype
!=
"fp8"
:
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
self
,
...
...
@@ -154,6 +155,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
...
...
@@ -172,6 +175,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
return
self
.
_v_up_proj
(
o
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment