Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f12cd32
Unverified
Commit
2f12cd32
authored
Dec 27, 2025
by
Boyuan Feng
Committed by
GitHub
Dec 27, 2025
Browse files
[BugFix] Fix cache issue in compilation_config (#31376)
Signed-off-by:
Boyuan Feng
<
boyuan@meta.com
>
parent
40a87562
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
0 deletions
+47
-0
tests/compile/test_config.py
tests/compile/test_config.py
+42
-0
vllm/config/vllm.py
vllm/config/vllm.py
+5
-0
No files found.
tests/compile/test_config.py
View file @
2f12cd32
...
@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
...
@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
vllm_config
.
compilation_config
.
max_cudagraph_capture_size
vllm_config
.
compilation_config
.
max_cudagraph_capture_size
==
expected_max_size
==
expected_max_size
)
)
def test_cached_compilation_config():
    """Verify that entering ``set_current_vllm_config`` drops a stale cached config.

    The default compilation config is read (and therefore cached) first. A
    config enabling the ``+quant_fp8`` custom op is then installed, and the
    code Inductor generates for a compiled ``QuantFP8`` module must call the
    custom op ``torch.ops._C.static_scaled_fp8_quant`` rather than an
    inductor-generated triton kernel.
    """
    import torch
    from torch._inductor.utils import run_and_get_code

    from vllm.config import get_cached_compilation_config, set_current_vllm_config
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    batch_size = 8
    num_qo_heads = 16
    head_size = 128

    # Populate the cache with the default compilation config. The default does
    # not enable the +quant_fp8 custom op, so if this cached entry were still
    # used below, the generated code would contain an inductor-generated triton
    # kernel instead of `torch.ops._C.static_scaled_fp8_quant`.
    get_cached_compilation_config()

    quant_fp8_compilation = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(compilation_config=quant_fp8_compilation)

    # Entering the context is expected to clear the cached compilation config
    # so the compilation_config carried by `vllm_config` is picked up instead.
    with set_current_vllm_config(vllm_config):
        compiled_quant = torch.compile(
            QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        )

        _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
        query = torch.randn(
            batch_size,
            num_qo_heads * head_size,
            dtype=dtype,
            device=device,
        )

        _, code_chunks = run_and_get_code(compiled_quant, query, _q_scale)

    generated_source = " ".join(code_chunks)
    assert "torch.ops._C.static_scaled_fp8_quant.default(" in generated_source
vllm/config/vllm.py
View file @
2f12cd32
...
@@ -1360,6 +1360,11 @@ def set_current_vllm_config(
...
@@ -1360,6 +1360,11 @@ def set_current_vllm_config(
num_models_seen
=
compilation_counter
.
num_models_seen
num_models_seen
=
compilation_counter
.
num_models_seen
try
:
try
:
# Clear the compilation config cache when context changes.
# This is needed since the old config may have been accessed
# and cached before the new config is set.
get_cached_compilation_config
.
cache_clear
()
_current_vllm_config
=
vllm_config
_current_vllm_config
=
vllm_config
_current_prefix
=
prefix
_current_prefix
=
prefix
yield
yield
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment