Unverified commit 2f12cd32, authored by Boyuan Feng and committed by GitHub
Browse files

[BugFix] Fix cache issue in compilation_config (#31376)


Signed-off-by: Boyuan Feng <boyuan@meta.com>
parent 40a87562
...@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init( ...@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
vllm_config.compilation_config.max_cudagraph_capture_size vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size == expected_max_size
) )
def test_cached_compilation_config():
    """Check that entering a new vLLM config context drops the cached
    compilation config.

    The default compilation config is accessed (and therefore cached) first.
    A config enabling the ``+quant_fp8`` custom op is then installed with
    ``set_current_vllm_config``; the code generated for a compiled ``QuantFP8``
    layer must call the custom op, which shows the new config was used
    rather than the previously cached default.
    """
    import torch
    from torch._inductor.utils import run_and_get_code

    from vllm.config import get_cached_compilation_config, set_current_vllm_config
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    batch_size, num_qo_heads, head_size = 8, 16, 128

    # access and cache default compilation config
    # default compilation config does not contain +quant_fp8 custom op. If this is
    # used, the generated code would use inductor-generated triton kernel instead
    # of the custom op `torch.ops._C.static_scaled_fp8_quant`.
    get_cached_compilation_config()

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+quant_fp8"],
        )
    )

    # set_current_vllm_config should clear cached compilation config and
    # use the new compilation_config in vllm_config
    with set_current_vllm_config(vllm_config):
        query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        query_quant = torch.compile(query_quant)

        # Allocate both inputs on the same explicit device so the scale and
        # the query cannot end up on different GPUs when the current CUDA
        # device is not index 0.
        _q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        query = torch.randn(
            batch_size, num_qo_heads * head_size, dtype=dtype, device=device
        )

        _, code = run_and_get_code(query_quant, query, _q_scale)

    # run_and_get_code returns the generated sources as a list; flatten it so
    # a single substring check covers every generated module.
    code = " ".join(code)
    assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
...@@ -1360,6 +1360,11 @@ def set_current_vllm_config( ...@@ -1360,6 +1360,11 @@ def set_current_vllm_config(
num_models_seen = compilation_counter.num_models_seen num_models_seen = compilation_counter.num_models_seen
try: try:
# Clear the compilation config cache when context changes.
# This is needed since the old config may have been accessed
# and cached before the new config is set.
get_cached_compilation_config.cache_clear()
_current_vllm_config = vllm_config _current_vllm_config = vllm_config
_current_prefix = prefix _current_prefix = prefix
yield yield
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment