Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e1f5a46
Commit
2e1f5a46
authored
Dec 19, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
parents
8ba8a855
1e622f10
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
4 deletions
+12
-4
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+3
-1
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+4
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+3
-1
No files found.
vllm/attention/backends/flashmla.py
View file @
2e1f5a46
...
@@ -266,7 +266,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -266,7 +266,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
to
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
head_dim_v
=
self
.
kv_lora_rank
,
head_dim_v
=
self
.
kv_lora_rank
,
...
@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
decode_meta
.
decode_num_splits
,
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
is_fp8_kvcache
=
False
,
indices
=
None
,
k_scale
=
k_scale
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
)
)
...
...
vllm/attention/layer.py
View file @
2e1f5a46
...
@@ -101,12 +101,13 @@ class Attention(nn.Module):
...
@@ -101,12 +101,13 @@ class Attention(nn.Module):
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_k_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
self
.
_k_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
self
.
_q_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
else
:
else
:
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# FlashAttn doesn't support quantizing the kv-cache only
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
# but requires q to be quantized as well.
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# We also keep the float32 versions of k/v_scale for attention
# We also keep the float32 versions of k/v_scale for attention
...
...
vllm/attention/ops/flashmla.py
View file @
2e1f5a46
...
@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
...
@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
num_splits
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
causal
:
bool
=
False
,
is_fp8_kvcache
:
bool
=
False
,
indices
:
Optional
[
torch
.
Tensor
]
=
None
,
k_scale
=
None
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
...
@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
q
,
k_cache
,
k_cache
,
None
,
head_dim_v
,
head_dim_v
,
cache_seqlens
,
cache_seqlens
,
block_table
,
block_table
,
...
@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
...
@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
causal
,
causal
,
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
is_fp8_kvcache
,
indices
,
)
)
else
:
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
2e1f5a46
...
@@ -194,7 +194,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -194,7 +194,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
to
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
head_dim_v
=
self
.
kv_lora_rank
,
...
@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
attn_metadata
.
decode
.
num_splits
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
is_fp8_kvcache
=
False
,
indices
=
None
,
k_scale
=
k_scale
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment