Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5a790ee
Unverified
Commit
a5a790ee
authored
Nov 10, 2025
by
Adrian Abeyta
Committed by
GitHub
Nov 10, 2025
Browse files
[Bugfix] Ensure calculated KV scales are applied in attention. (#27232)
Signed-off-by:
adabeyta
<
aabeyta@redhat.com
>
parent
b30372cb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
36 deletions
+29
-36
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+5
-2
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+8
-2
vllm/attention/layer.py
vllm/attention/layer.py
+7
-22
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+9
-10
No files found.
.buildkite/test-pipeline.yaml
View file @
a5a790ee
...
@@ -471,8 +471,8 @@ steps:
...
@@ -471,8 +471,8 @@ steps:
-
vllm/
-
vllm/
-
tests/compile
-
tests/compile
commands
:
commands
:
-
pytest -v -s compile/test_full_graph.py
-
pytest -v -s compile/test_full_graph.py
-k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
# Wrap with quotes to escape yaml and avoid starting -k string with a -
-
"
pytest
-v
-s
compile/test_fusions_e2e.py
-k
'TRITON
and
-quant_fp8'"
-
"
pytest
-v
-s
compile/test_fusions_e2e.py
-k
'TRITON
and
-quant_fp8'"
...
@@ -951,10 +951,13 @@ steps:
...
@@ -951,10 +951,13 @@ steps:
-
vllm/model_executor/layers/activation.py
-
vllm/model_executor/layers/activation.py
-
vllm/model_executor/layers/quantization/input_quant_fp8.py
-
vllm/model_executor/layers/quantization/input_quant_fp8.py
-
tests/compile/test_fusions_e2e.py
-
tests/compile/test_fusions_e2e.py
-
tests/compile/test_full_graph.py
commands
:
commands
:
-
nvidia-smi
-
nvidia-smi
# Run all e2e fusion tests
# Run all e2e fusion tests
-
pytest -v -s tests/compile/test_fusions_e2e.py
-
pytest -v -s tests/compile/test_fusions_e2e.py
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
-
pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
-
label
:
Blackwell GPT-OSS Eval
-
label
:
Blackwell GPT-OSS Eval
timeout_in_minutes
:
60
timeout_in_minutes
:
60
...
...
tests/compile/test_full_graph.py
View file @
a5a790ee
...
@@ -183,8 +183,14 @@ def test_custom_compile_config(
...
@@ -183,8 +183,14 @@ def test_custom_compile_config(
"compilation_mode"
,
"compilation_mode"
,
[
CompilationMode
.
NONE
,
CompilationMode
.
VLLM_COMPILE
],
[
CompilationMode
.
NONE
,
CompilationMode
.
VLLM_COMPILE
],
)
)
def
test_fp8_kv_scale_compile
(
compilation_mode
:
int
):
@
pytest
.
mark
.
parametrize
(
model
=
"Qwen/Qwen2-0.5B"
"model"
,
[
"Qwen/Qwen2-0.5B"
,
# Standard attention model
"deepseek-ai/DeepSeek-V2-Lite"
,
# MLA (Multi-head Latent Attention) model
],
)
def
test_fp8_kv_scale_compile
(
compilation_mode
:
int
,
model
:
str
):
model_kwargs
=
{
model_kwargs
=
{
"quantization"
:
"fp8"
,
"quantization"
:
"fp8"
,
"kv_cache_dtype"
:
"fp8_e4m3"
,
"kv_cache_dtype"
:
"fp8_e4m3"
,
...
...
vllm/attention/layer.py
View file @
a5a790ee
...
@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_shape
:
torch
.
Size
|
None
=
None
,
output_shape
:
torch
.
Size
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
calculate_kv_scales
:
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
)
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
...
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
# Mirror Attention.forward scale calculation path
if
self
.
calculate_kv_scales
and
getattr
(
attn_metadata
,
"enable_kv_scales_calculation"
,
False
):
self
.
calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
)
if
self
.
attn_backend
.
accept_output_buffer
:
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
self
.
impl
.
forward
(
self
.
impl
.
forward
(
...
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
)
return
output
return
output
else
:
else
:
# We can still access forward context to check calculation flag
if
self
.
calculate_kv_scales
:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
if
getattr
(
attn_metadata
,
"enable_kv_scales_calculation"
,
False
):
self
.
calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
)
return
torch
.
ops
.
vllm
.
unified_mla_attention
(
return
torch
.
ops
.
vllm
.
unified_mla_attention
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
...
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
...
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
layer_name
:
str
,
layer_name
:
str
,
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
no_compile_layers
[
layer_name
]
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
if
attn_metad
at
a
i
s
None
or
not
getattr
(
# Only calcul
at
e
i
f the layer's calculate_kv_scales flag is True
attn_metadata
,
"enable_kv_scales_calculation"
,
False
# This flag gets set to False after the first forward pass
)
:
if
not
self
.
calculate_kv_scales
:
return
return
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
calc_kv_scales
(
query
,
key
,
value
)
self
.
calc_kv_scales
(
query
,
key
,
value
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a5a790ee
...
@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This will be overridden in load_model()
# This will be overridden in load_model()
self
.
is_multimodal_pruning_enabled
=
False
self
.
is_multimodal_pruning_enabled
=
False
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
# Always set to false after the first forward pass
self
.
calculate_kv_scales
=
self
.
cache_config
.
calculate_kv_scales
self
.
dcp_world_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
dcp_world_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
...
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
# Set cudagraph mode to none if calc_kv_scales is true.
# Set cudagraph mode to none if calc_kv_scales is true.
if
attn_metadata
is
not
None
:
# KV scales calculation involves dynamic operations that are incompatible
metadata_list
=
(
# with CUDA graph capture.
attn_metadata
.
values
()
if
self
.
calculate_kv_scales
:
if
isinstance
(
attn_metadata
,
dict
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
else
[
attn_metadata
]
# Mark KV scales as calculated after the first forward pass
)
self
.
calculate_kv_scales
=
False
if
any
(
getattr
(
m
,
"enable_kv_scales_calculation"
,
False
)
for
m
in
metadata_list
):
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
# Run the model.
# Run the model.
# Use persistent buffers for CUDA graphs.
# Use persistent buffers for CUDA graphs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment