Unverified commit 6cdf015c, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues...


[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 5d3b6097
......@@ -111,7 +111,7 @@ def create_packed_lora(
return LoRAModel(lora_id, 8, loras)
def test_replace_submodules(dist_init, dummy_model):
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(
model,
......@@ -133,7 +133,7 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
model = dummy_model
model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
......@@ -199,7 +199,9 @@ def test_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
def test_lora_lru_cache_model_manager(
default_vllm_config, dist_init, dummy_model, device
):
model = dummy_model
model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
......@@ -289,7 +291,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
......@@ -415,7 +417,9 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path):
def test_lru_cache_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model, device, tmp_path
):
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)
......@@ -529,7 +533,9 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path):
def test_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
......@@ -636,7 +642,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model_lora = create_packed_lora(
1,
......
......@@ -55,7 +55,7 @@ def test_get_draft_quant_config_without_draft_model():
@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> None:
import torch
from vllm.model_executor.layers.linear import ReplicatedLinear
......
......@@ -73,7 +73,9 @@ def run_intern_vit_test(
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
def test_models(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_intern_vit_test(
image_assets,
model_id,
......
......@@ -92,7 +92,9 @@ def run_radio_test(
],
)
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
def test_radio(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_radio_test(
image_assets,
model_id,
......
......@@ -145,6 +145,9 @@ def initialize_dummy_model(
model_config: ModelConfig,
):
temp_file = tempfile.mkstemp()[1]
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment(
world_size=1,
rank=0,
......@@ -154,9 +157,6 @@ def initialize_dummy_model(
)
initialize_model_parallel(tensor_model_parallel_size=1)
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config):
with set_default_torch_dtype(model_config.dtype):
torch.set_default_device(current_platform.device_type)
model = model_cls(vllm_config=vllm_config)
......
......@@ -31,7 +31,7 @@ def test_platform_plugins():
)
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
def test_oot_custom_op(default_vllm_config, monkeypatch: pytest.MonkeyPatch):
# simulate workload by running an example
load_general_plugins()
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
......
......@@ -277,6 +277,7 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
default_vllm_config,
method_cls,
is_checkpoint_fp8_serialized,
weight_block_size,
......
......@@ -721,7 +721,28 @@ def init_test_distributed_environment(
distributed_init_port: str,
local_rank: int = -1,
) -> None:
# Note: This function is often called from Ray worker processes, so we
# can't rely on pytest fixtures to set the config. We check if the config
# is already set and only create a default one if needed.
from vllm.config import (
VllmConfig,
get_current_vllm_config_or_none,
set_current_vllm_config,
)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if get_current_vllm_config_or_none() is not None:
# Config already set, use it directly
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
......
......@@ -556,7 +556,7 @@ def _test_backend_correctness(
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with causal attention."""
......
......@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
],
)
def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
default_vllm_config,
dist_init,
layer_class,
init_kwargs,
expected_backend,
expected_mamba_type,
):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)
......
......@@ -394,7 +394,11 @@ def run_attention_backend(
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
def test_backend_correctness(
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
default_vllm_config,
dist_init,
batch_spec_name: str,
model: str,
tensor_parallel_size: int,
):
"""
Test that all backends produce similar outputs to a reference implementation
......
......@@ -124,7 +124,12 @@ def _quantize_dequantize_fp8_ds_mla(
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
default_vllm_config,
dist_init,
batch_name,
kv_cache_dtype,
tensor_parallel_size,
workspace_init,
):
if current_platform.is_rocm():
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
......
......@@ -21,7 +21,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
default_vllm_config,
batch_size: int,
hidden_size: int,
dtype: torch.dtype,
eps: float,
):
"""
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
......@@ -68,7 +72,9 @@ def test_rms_norm_batch_invariant_vs_standard(
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
def test_rms_norm_3d_input(
default_vllm_config, batch_size: int, seq_len: int, hidden_size: int
):
"""
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
......@@ -107,7 +113,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
@skip_unsupported
def test_rms_norm_numerical_stability():
def test_rms_norm_numerical_stability(default_vllm_config):
"""
Test RMS norm numerical stability with extreme values.
......@@ -167,7 +173,7 @@ def test_rms_norm_numerical_stability():
@skip_unsupported
def test_rms_norm_formula():
def test_rms_norm_formula(default_vllm_config):
"""
Test that RMS norm follows the correct mathematical formula.
......@@ -201,7 +207,7 @@ def test_rms_norm_formula():
@skip_unsupported
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
"""
Test RMS norm with various hidden sizes to ensure block size handling.
......@@ -238,7 +244,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
@skip_unsupported
def test_rms_norm_determinism():
def test_rms_norm_determinism(default_vllm_config):
"""
Test that batch-invariant RMS norm produces deterministic results.
......
......@@ -299,6 +299,7 @@ def test_prompt_less_than_block_size():
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
from vllm.config import set_current_vllm_config
# Test setup, we create a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
......@@ -308,6 +309,7 @@ def test_kv_transfer_handshake(dist_init):
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
with set_current_vllm_config(vllm_config):
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
......@@ -458,6 +460,7 @@ class TestNixlHandshake:
)
def test_multi_xfer_one_engine(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
......@@ -547,6 +550,7 @@ class TestNixlHandshake:
)
def test_async_load_kv(
self,
default_vllm_config,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
......@@ -605,7 +609,7 @@ class TestNixlHandshake:
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
self, local_tp_size: int, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
......@@ -670,7 +674,7 @@ class TestNixlHandshake:
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
self, local_tp_size: int, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
......@@ -770,6 +774,7 @@ class TestNixlHandshake:
)
def test_concurrent_load_kv(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
......@@ -830,7 +835,9 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
def test_handshake_fails_on_kv_cache_layout_mismatch(
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP.
......@@ -879,7 +886,7 @@ class TestNixlHandshake:
FakeNixlWrapper,
)
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
......@@ -934,7 +941,7 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_connector_stats(dist_init):
def test_kv_connector_stats(default_vllm_config, dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
......@@ -1357,7 +1364,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend):
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
......@@ -1518,7 +1525,9 @@ class FakePlatform(Platform):
("oot", "VRAM"),
],
)
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
def test_kv_buffer_to_nixl_memory_types(
default_vllm_config, dist_init, kv_buffer_device, nixl_memory_type
):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
......@@ -1563,7 +1572,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_shutdown_cleans_up_resources(dist_init):
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
......@@ -1622,7 +1631,7 @@ def test_shutdown_cleans_up_resources(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler
......@@ -1731,7 +1740,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config()
......@@ -1780,7 +1789,7 @@ def test_handshake_failure_returns_finished(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
"""Test that transfer setup failures mark blocks invalid
and return via get_finished."""
vllm_config = create_vllm_config()
......@@ -1855,6 +1864,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
FakeNixlWrapper,
)
def test_compatibility_hash_validation(
default_vllm_config,
dist_init,
mismatch_type,
config_overrides,
......@@ -1967,7 +1977,7 @@ def test_compatibility_hash_validation(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
"""
Test that msgspec decode errors are properly handled during handshake.
......
......@@ -50,6 +50,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
default_vllm_config,
gpu_to_cpu: bool,
num_mappings: int,
head_size: int,
......
......@@ -112,6 +112,7 @@ def get_vllm_config():
@pytest.fixture
def model_runner():
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
......@@ -120,7 +121,7 @@ def model_runner():
)
runner = GPUModelRunner(vllm_config, DEVICE)
initialize_kv_cache(runner)
return runner
yield runner
model_runner_2 = model_runner
......@@ -546,7 +547,7 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -573,7 +574,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -600,7 +601,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -627,7 +628,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_without_kv_sharing(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -694,7 +695,7 @@ def test_init_kv_cache_without_kv_sharing():
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -1047,7 +1048,7 @@ def test_input_batch_with_kernel_block_sizes():
assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init):
def test_hybrid_cache_integration(default_vllm_config, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()
......
......@@ -6,14 +6,14 @@ import torch
from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():
def test_bind_kv_cache(default_vllm_config):
from vllm.attention.layer import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
"layers.1.self_attn": Attention(32, 128, 0.1),
"layers.2.self_attn": Attention(32, 128, 0.1),
"layers.3.self_attn": Attention(32, 128, 0.1),
"layers.0.self_attn": Attention(32, 128, 0.1, prefix="layers.0.self_attn"),
"layers.1.self_attn": Attention(32, 128, 0.1, prefix="layers.1.self_attn"),
"layers.2.self_attn": Attention(32, 128, 0.1, prefix="layers.2.self_attn"),
"layers.3.self_attn": Attention(32, 128, 0.1, prefix="layers.3.self_attn"),
}
kv_cache = {
"layers.0.self_attn": torch.zeros((1,)),
......@@ -34,13 +34,13 @@ def test_bind_kv_cache():
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
def test_bind_kv_cache_non_attention():
def test_bind_kv_cache_non_attention(default_vllm_config):
from vllm.attention.layer import Attention
# example from Jamba PP=2
ctx = {
"model.layers.20.attn": Attention(32, 128, 0.1),
"model.layers.28.attn": Attention(32, 128, 0.1),
"model.layers.20.attn": Attention(32, 128, 0.1, prefix="model.layers.20.attn"),
"model.layers.28.attn": Attention(32, 128, 0.1, prefix="model.layers.28.attn"),
}
kv_cache = {
"model.layers.20.attn": torch.zeros((1,)),
......
......@@ -59,10 +59,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
)
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
......
......@@ -42,6 +42,7 @@ from vllm.config.vllm import (
VllmConfig,
get_cached_compilation_config,
get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config,
set_current_vllm_config,
)
......@@ -105,6 +106,7 @@ __all__ = [
"VllmConfig",
"get_cached_compilation_config",
"get_current_vllm_config",
"get_current_vllm_config_or_none",
"set_current_vllm_config",
"get_layers_from_vllm_config",
]
......@@ -1441,13 +1441,18 @@ def get_cached_compilation_config():
def get_current_vllm_config() -> VllmConfig:
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
# Use stack level 2 so the log contains the line of the caller,
# so it's easier to track down the source of the warning.
logger.warning("Current vLLM config is not set.", stacklevel=2)
return VllmConfig()
raise AssertionError(
"Current vLLM config is not set. This typically means "
"get_current_vllm_config() was called outside of a "
"set_current_vllm_config() context, or a CustomOp was instantiated "
"at module import time or model forward time when config is not set. "
"For tests that directly test custom ops/modules, use the "
"'default_vllm_config' pytest fixture from tests/conftest.py."
)
return _current_vllm_config
def get_current_vllm_config_or_none() -> VllmConfig | None:
return _current_vllm_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment