Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d01cce7
Commit
3d01cce7
authored
Jan 21, 2026
by
xiabo
Browse files
1、kvcache支持fp8的scale
parent
6dcb89d2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
7 deletions
+42
-7
vllm/attention/layer.py
vllm/attention/layer.py
+37
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-4
No files found.
vllm/attention/layer.py
View file @
3d01cce7
...
...
@@ -282,9 +282,11 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if
self
.
calculate_kv_scales
:
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
query
,
key
,
value
)
# attn_metadata = get_forward_context().attn_metadata
# if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
output_dtype
=
query
.
dtype
if
self
.
query_quant
is
not
None
:
...
...
@@ -570,6 +572,38 @@ def maybe_save_kv_layer_to_connector(
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
def
maybe_calc_kv_scales
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if
not
self
.
calculate_kv_scales
:
return
self
.
calc_kv_scales
(
query
,
key
,
value
)
def
maybe_calc_kv_scales_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"maybe_calc_kv_scales"
,
op_func
=
maybe_calc_kv_scales
,
mutates_args
=
[
"query"
,
"key"
,
"value"
],
fake_impl
=
maybe_calc_kv_scales_fake
,
)
def
unified_attention
(
query
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
3d01cce7
...
...
@@ -593,9 +593,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
# descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
if
not
current_platform
.
is_rocm
():
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
...
...
@@ -643,8 +643,9 @@ class FlashAttentionImpl(AttentionImpl):
scheduler_metadata
=
scheduler_metadata
,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment