Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0895c00
Commit
a0895c00
authored
Sep 23, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev-yql-kvfp8' into v0.9.2-dev
parents
4020670f
8b7daa0d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
8 deletions
+27
-8
vllm/_custom_ops.py
vllm/_custom_ops.py
+13
-3
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+6
-2
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+6
-2
No files found.
vllm/_custom_ops.py
View file @
a0895c00
...
@@ -2162,9 +2162,19 @@ def gather_cache(src_cache: torch.Tensor,
...
@@ -2162,9 +2162,19 @@ def gather_cache(src_cache: torch.Tensor,
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
batch_size
:
int
,
batch_size
:
int
,
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
,
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
kv_dtype
=
"auto"
,
cu_seq_lens
,
batch_size
,
seq_starts
)
scale
:
float
=
1.0
,
)
->
None
:
#支持"kv cache fp8"
if
kv_dtype
==
"fp8"
:
dst_fp8
=
torch
.
zeros
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
convert_fp8
(
dst_fp8
,
dst
,
scale
,
kv_dtype
)
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
else
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
...
...
vllm/attention/backends/mla/common.py
View file @
a0895c00
...
@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
):
):
prefill_metadata
=
attn_metadata
.
prefill_metadata
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
assert
prefill_metadata
is
not
None
...
@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
cu_seq_lens
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
cu_seq_lens
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
batch_size
=
prefill_metadata
.
num_prefills
,
batch_size
=
prefill_metadata
.
num_prefills
,
seq_starts
=
prefill_metadata
.
context_chunk_starts
[
i
],
seq_starts
=
prefill_metadata
.
context_chunk_starts
[
i
],
kv_dtype
=
self
.
kv_cache_dtype
,
scale
=
kv_scale
,
)
)
kv_c_normed
=
workspace
[:
toks
]
\
kv_c_normed
=
workspace
[:
toks
]
\
...
@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
prefill_metadata
=
attn_metadata
.
prefill_metadata
...
@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
kv_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
merge_attn_states
(
...
@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
has_prefill
:
if
has_prefill
:
output
[:
num_prefill_tokens
]
=
self
.
_forward_prefill
(
output
[:
num_prefill_tokens
]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
if
has_decode
:
if
has_decode
:
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
...
...
vllm/attention/layer.py
View file @
a0895c00
...
@@ -205,7 +205,8 @@ class Attention(nn.Module):
...
@@ -205,7 +205,8 @@ class Attention(nn.Module):
"""
"""
if
self
.
calculate_kv_scales
:
if
self
.
calculate_kv_scales
:
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
enable_kv_scales_calculation
:
if
(
attn_metadata
is
not
None
and
getattr
(
attn_metadata
,
"enable_kv_scales_calculation"
,
False
)):
# if key is not None and value is not None:
self
.
calc_kv_scales
(
query
,
key
,
value
)
self
.
calc_kv_scales
(
query
,
key
,
value
)
if
self
.
use_output
:
if
self
.
use_output
:
output_shape
=
(
output_shape
output_shape
=
(
output_shape
...
...
vllm/v1/attention/backends/mla/common.py
View file @
a0895c00
...
@@ -894,6 +894,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -894,6 +894,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
):
):
assert
attn_metadata
.
prefill
is
not
None
assert
attn_metadata
.
prefill
is
not
None
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
attn_metadata
.
prefill
...
@@ -913,6 +914,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -913,6 +914,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
batch_size
=
attn_metadata
.
num_prefills
,
batch_size
=
attn_metadata
.
num_prefills
,
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
kv_dtype
=
self
.
kv_cache_dtype
,
scale
=
kv_scale
,
)
)
kv_c_normed
=
workspace
[:
toks
]
\
kv_c_normed
=
workspace
[:
toks
]
\
...
@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
attn_metadata
.
prefill
is
not
None
assert
attn_metadata
.
prefill
is
not
None
...
@@ -1015,7 +1019,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1015,7 +1019,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_context
:
if
has_context
:
suffix_output
,
suffix_lse
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
kv_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
merge_attn_states
(
...
@@ -1104,7 +1108,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1104,7 +1108,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_prefill
:
if
has_prefill
:
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
if
has_decode
:
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment