Unverified Commit 6cdf015c authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues...


[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues in the future (#31747)
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 5d3b6097
...@@ -111,7 +111,7 @@ def create_packed_lora( ...@@ -111,7 +111,7 @@ def create_packed_lora(
return LoRAModel(lora_id, 8, loras) return LoRAModel(lora_id, 8, loras)
def test_replace_submodules(dist_init, dummy_model): def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
model = dummy_model model = dummy_model
manager = LoRAModelManager( manager = LoRAModelManager(
model, model,
...@@ -133,7 +133,7 @@ def test_replace_submodules(dist_init, dummy_model): ...@@ -133,7 +133,7 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device): def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
model = dummy_model model = dummy_model
model_lora1 = create_lora( model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
...@@ -199,7 +199,9 @@ def test_lora_model_manager(dist_init, dummy_model, device): ...@@ -199,7 +199,9 @@ def test_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): def test_lora_lru_cache_model_manager(
default_vllm_config, dist_init, dummy_model, device
):
model = dummy_model model = dummy_model
model_lora1 = create_lora( model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
...@@ -289,7 +291,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): ...@@ -289,7 +291,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device): def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is # This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager # tested in test_lora_model_manager
model = dummy_model model = dummy_model
...@@ -415,7 +417,9 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): ...@@ -415,7 +417,9 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path): def test_lru_cache_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model, device, tmp_path
):
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
) )
...@@ -529,7 +533,9 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa ...@@ -529,7 +533,9 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path): def test_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
# Should remove every LoRA not specified in the request. # Should remove every LoRA not specified in the request.
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
...@@ -636,7 +642,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path ...@@ -636,7 +642,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device): def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up model = dummy_model_gate_up
model_lora = create_packed_lora( model_lora = create_packed_lora(
1, 1,
......
...@@ -55,7 +55,7 @@ def test_get_draft_quant_config_without_draft_model(): ...@@ -55,7 +55,7 @@ def test_get_draft_quant_config_without_draft_model():
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_fc_layer_quant_config_usage(dist_init, device) -> None: def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> None:
import torch import torch
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
......
...@@ -73,7 +73,9 @@ def run_intern_vit_test( ...@@ -73,7 +73,9 @@ def run_intern_vit_test(
], ],
) )
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None: def test_models(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_intern_vit_test( run_intern_vit_test(
image_assets, image_assets,
model_id, model_id,
......
...@@ -92,7 +92,9 @@ def run_radio_test( ...@@ -92,7 +92,9 @@ def run_radio_test(
], ],
) )
@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("dtype", ["half", "bfloat16"])
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None: def test_radio(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_radio_test( run_radio_test(
image_assets, image_assets,
model_id, model_id,
......
...@@ -145,18 +145,18 @@ def initialize_dummy_model( ...@@ -145,18 +145,18 @@ def initialize_dummy_model(
model_config: ModelConfig, model_config: ModelConfig,
): ):
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
current_device = torch.get_default_device() current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config) vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config): with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
torch.set_default_device(current_platform.device_type) torch.set_default_device(current_platform.device_type)
model = model_cls(vllm_config=vllm_config) model = model_cls(vllm_config=vllm_config)
......
...@@ -31,7 +31,7 @@ def test_platform_plugins(): ...@@ -31,7 +31,7 @@ def test_platform_plugins():
) )
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch): def test_oot_custom_op(default_vllm_config, monkeypatch: pytest.MonkeyPatch):
# simulate workload by running an example # simulate workload by running an example
load_general_plugins() load_general_plugins()
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
......
...@@ -277,6 +277,7 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -277,6 +277,7 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod # this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True @pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading( def test_fp8_reloading(
default_vllm_config,
method_cls, method_cls,
is_checkpoint_fp8_serialized, is_checkpoint_fp8_serialized,
weight_block_size, weight_block_size,
......
...@@ -721,13 +721,34 @@ def init_test_distributed_environment( ...@@ -721,13 +721,34 @@ def init_test_distributed_environment(
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}" # Note: This function is often called from Ray worker processes, so we
init_distributed_environment( # can't rely on pytest fixtures to set the config. We check if the config
world_size=pp_size * tp_size, # is already set and only create a default one if needed.
rank=rank, from vllm.config import (
distributed_init_method=distributed_init_method, VllmConfig,
local_rank=local_rank, get_current_vllm_config_or_none,
set_current_vllm_config,
) )
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if get_current_vllm_config_or_none() is not None:
# Config already set, use it directly
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
......
...@@ -556,7 +556,7 @@ def _test_backend_correctness( ...@@ -556,7 +556,7 @@ def _test_backend_correctness(
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness( def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
): ):
"""Test backend's correctness with causal attention.""" """Test backend's correctness with causal attention."""
......
...@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend ...@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
], ],
) )
def test_mamba_layers_get_attn_backend( def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type default_vllm_config,
dist_init,
layer_class,
init_kwargs,
expected_backend,
expected_mamba_type,
): ):
"""Test that Mamba-like layers return the correct attention backend.""" """Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs) layer = layer_class(**init_kwargs)
......
...@@ -394,7 +394,11 @@ def run_attention_backend( ...@@ -394,7 +394,11 @@ def run_attention_backend(
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"]) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16]) @pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
def test_backend_correctness( def test_backend_correctness(
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int default_vllm_config,
dist_init,
batch_spec_name: str,
model: str,
tensor_parallel_size: int,
): ):
""" """
Test that all backends produce similar outputs to a reference implementation Test that all backends produce similar outputs to a reference implementation
......
...@@ -124,7 +124,12 @@ def _quantize_dequantize_fp8_ds_mla( ...@@ -124,7 +124,12 @@ def _quantize_dequantize_fp8_ds_mla(
reason="FlashMLASparseBackend requires CUDA 9.0 or higher", reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
) )
def test_sparse_backend_decode_correctness( def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init default_vllm_config,
dist_init,
batch_name,
kv_cache_dtype,
tensor_parallel_size,
workspace_init,
): ):
if current_platform.is_rocm(): if current_platform.is_rocm():
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.") pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
......
...@@ -21,7 +21,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -21,7 +21,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5]) @pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard( def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float default_vllm_config,
batch_size: int,
hidden_size: int,
dtype: torch.dtype,
eps: float,
): ):
""" """
Compare batch-invariant Triton RMS norm against standard CUDA implementation. Compare batch-invariant Triton RMS norm against standard CUDA implementation.
...@@ -68,7 +72,9 @@ def test_rms_norm_batch_invariant_vs_standard( ...@@ -68,7 +72,9 @@ def test_rms_norm_batch_invariant_vs_standard(
@pytest.mark.parametrize("batch_size", [1, 16, 128]) @pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512]) @pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096]) @pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int): def test_rms_norm_3d_input(
default_vllm_config, batch_size: int, seq_len: int, hidden_size: int
):
""" """
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size). Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
...@@ -107,7 +113,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int): ...@@ -107,7 +113,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
@skip_unsupported @skip_unsupported
def test_rms_norm_numerical_stability(): def test_rms_norm_numerical_stability(default_vllm_config):
""" """
Test RMS norm numerical stability with extreme values. Test RMS norm numerical stability with extreme values.
...@@ -167,7 +173,7 @@ def test_rms_norm_numerical_stability(): ...@@ -167,7 +173,7 @@ def test_rms_norm_numerical_stability():
@skip_unsupported @skip_unsupported
def test_rms_norm_formula(): def test_rms_norm_formula(default_vllm_config):
""" """
Test that RMS norm follows the correct mathematical formula. Test that RMS norm follows the correct mathematical formula.
...@@ -201,7 +207,7 @@ def test_rms_norm_formula(): ...@@ -201,7 +207,7 @@ def test_rms_norm_formula():
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384]) @pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int): def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
""" """
Test RMS norm with various hidden sizes to ensure block size handling. Test RMS norm with various hidden sizes to ensure block size handling.
...@@ -238,7 +244,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int): ...@@ -238,7 +244,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
@skip_unsupported @skip_unsupported
def test_rms_norm_determinism(): def test_rms_norm_determinism(default_vllm_config):
""" """
Test that batch-invariant RMS norm produces deterministic results. Test that batch-invariant RMS norm produces deterministic results.
......
...@@ -299,6 +299,7 @@ def test_prompt_less_than_block_size(): ...@@ -299,6 +299,7 @@ def test_prompt_less_than_block_size():
) )
def test_kv_transfer_handshake(dist_init): def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality.""" """Unit test for basic NixlConnector interface functionality."""
from vllm.config import set_current_vllm_config
# Test setup, we creates a scheduler that contains a NixlConnector # Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
...@@ -308,81 +309,82 @@ def test_kv_transfer_handshake(dist_init): ...@@ -308,81 +309,82 @@ def test_kv_transfer_handshake(dist_init):
vllm_config.kv_transfer_config.kv_buffer_device = "cpu" vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of with set_current_vllm_config(vllm_config):
# the scheduler (prefill), the other is a worker of decode instance. # Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake # Prefill connector will register KV cache to populate proper handshake
# metadata. # metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
) )
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size assert received_metadata[2] == 1 # remote_tp_size
# Need to shutdown the background thread to release NIXL side channel port # Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown() scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker): class FakeNixlConnectorWorker(NixlConnectorWorker):
...@@ -458,6 +460,7 @@ class TestNixlHandshake: ...@@ -458,6 +460,7 @@ class TestNixlHandshake:
) )
def test_multi_xfer_one_engine( def test_multi_xfer_one_engine(
self, self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment. # dist_init is a fixture that initializes the distributed environment.
dist_init, dist_init,
): ):
...@@ -547,6 +550,7 @@ class TestNixlHandshake: ...@@ -547,6 +550,7 @@ class TestNixlHandshake:
) )
def test_async_load_kv( def test_async_load_kv(
self, self,
default_vllm_config,
# Fixture that initializes the distributed environment. # Fixture that initializes the distributed environment.
dist_init, dist_init,
# Simulate consumer-producer TP sizes. # Simulate consumer-producer TP sizes.
...@@ -605,7 +609,7 @@ class TestNixlHandshake: ...@@ -605,7 +609,7 @@ class TestNixlHandshake:
) )
@pytest.mark.parametrize("local_tp_size", [1, 2]) @pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size( def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init self, local_tp_size: int, default_vllm_config, dist_init
): ):
""" """
Verify remote TP > local TP handshake succeeds with different Verify remote TP > local TP handshake succeeds with different
...@@ -670,7 +674,7 @@ class TestNixlHandshake: ...@@ -670,7 +674,7 @@ class TestNixlHandshake:
) )
@pytest.mark.parametrize("local_tp_size", [1, 2]) @pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla( def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init self, local_tp_size: int, default_vllm_config, dist_init
): ):
""" """
Verify remote TP > local TP handshake succeeds with different Verify remote TP > local TP handshake succeeds with different
...@@ -770,6 +774,7 @@ class TestNixlHandshake: ...@@ -770,6 +774,7 @@ class TestNixlHandshake:
) )
def test_concurrent_load_kv( def test_concurrent_load_kv(
self, self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment. # dist_init is a fixture that initializes the distributed environment.
dist_init, dist_init,
): ):
...@@ -830,7 +835,9 @@ class TestNixlHandshake: ...@@ -830,7 +835,9 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): def test_handshake_fails_on_kv_cache_layout_mismatch(
self, default_vllm_config, dist_init
):
""" """
Verify that adding a remote agent fails if kv_cache_layout differs. Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP. This test is only relevant for heterogeneous TP.
...@@ -879,7 +886,7 @@ class TestNixlHandshake: ...@@ -879,7 +886,7 @@ class TestNixlHandshake:
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init self, default_vllm_config, dist_init
): ):
""" """
Verify that adding a remote agent fails if kv_cache_layout differs. Verify that adding a remote agent fails if kv_cache_layout differs.
...@@ -934,7 +941,7 @@ class TestNixlHandshake: ...@@ -934,7 +941,7 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_kv_connector_stats(dist_init): def test_kv_connector_stats(default_vllm_config, dist_init):
"""Test that KV transfer stats are properly recorded and retrieved.""" """Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
...@@ -1357,7 +1364,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): ...@@ -1357,7 +1364,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN", "TRITON_ATTN",
], ],
) )
def test_register_kv_caches(dist_init, attn_backend): def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data. correct data.
...@@ -1518,7 +1525,9 @@ class FakePlatform(Platform): ...@@ -1518,7 +1525,9 @@ class FakePlatform(Platform):
("oot", "VRAM"), ("oot", "VRAM"),
], ],
) )
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type): def test_kv_buffer_to_nixl_memory_types(
default_vllm_config, dist_init, kv_buffer_device, nixl_memory_type
):
""" """
Test that register_kv_caches() passes the correct memory types from the Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper. config to the nixl_wrapper.
...@@ -1563,7 +1572,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory ...@@ -1563,7 +1572,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_shutdown_cleans_up_resources(dist_init): def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
"""Test that shutdown() properly cleans up all resources.""" """Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
...@@ -1622,7 +1631,7 @@ def test_shutdown_cleans_up_resources(dist_init): ...@@ -1622,7 +1631,7 @@ def test_shutdown_cleans_up_resources(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_aborted_request_removed_from_worker_in_batch(dist_init): def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
""" """
Create and schedule a request so that P adds it to in-batch tracking via Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler the real scheduler, then simulate an abort (request not in next scheduler
...@@ -1731,7 +1740,7 @@ class FailingNixlWrapper(FakeNixlWrapper): ...@@ -1731,7 +1740,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_handshake_failure_returns_finished(dist_init): def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished.""" """Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
...@@ -1780,7 +1789,7 @@ def test_handshake_failure_returns_finished(dist_init): ...@@ -1780,7 +1789,7 @@ def test_handshake_failure_returns_finished(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_transfer_setup_failure_returns_finished(dist_init): def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
"""Test that transfer setup failures mark blocks invalid """Test that transfer setup failures mark blocks invalid
and return via get_finished.""" and return via get_finished."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
...@@ -1855,6 +1864,7 @@ def test_transfer_setup_failure_returns_finished(dist_init): ...@@ -1855,6 +1864,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_compatibility_hash_validation( def test_compatibility_hash_validation(
default_vllm_config,
dist_init, dist_init,
mismatch_type, mismatch_type,
config_overrides, config_overrides,
...@@ -1967,7 +1977,7 @@ def test_compatibility_hash_validation( ...@@ -1967,7 +1977,7 @@ def test_compatibility_hash_validation(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_decode_errors(dist_init, error_scenario): def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
""" """
Test that msgspec decode errors are properly handled during handshake. Test that msgspec decode errors are properly handled during handshake.
......
...@@ -50,6 +50,7 @@ NUM_MAPPINGS = [3] ...@@ -50,6 +50,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_transfer( def test_transfer(
default_vllm_config,
gpu_to_cpu: bool, gpu_to_cpu: bool,
num_mappings: int, num_mappings: int,
head_size: int, head_size: int,
......
...@@ -112,15 +112,16 @@ def get_vllm_config(): ...@@ -112,15 +112,16 @@ def get_vllm_config():
@pytest.fixture @pytest.fixture
def model_runner(): def model_runner():
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
model_config = vllm_config.model_config with set_current_vllm_config(vllm_config):
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) model_config = vllm_config.model_config
head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( head_size = model_config.get_head_size()
num_heads, head_size, 0.1 vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
) num_heads, head_size, 0.1
runner = GPUModelRunner(vllm_config, DEVICE) )
initialize_kv_cache(runner) runner = GPUModelRunner(vllm_config, DEVICE)
return runner initialize_kv_cache(runner)
yield runner
model_runner_2 = model_runner model_runner_2 = model_runner
...@@ -546,7 +547,7 @@ def test_reload_weights_before_load_model(model_runner): ...@@ -546,7 +547,7 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights() model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
...@@ -573,7 +574,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): ...@@ -573,7 +574,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config):
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
...@@ -600,7 +601,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): ...@@ -600,7 +601,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current(): def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config):
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
...@@ -627,7 +628,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(): ...@@ -627,7 +628,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing(): def test_init_kv_cache_without_kv_sharing(default_vllm_config):
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
...@@ -694,7 +695,7 @@ def test_init_kv_cache_without_kv_sharing(): ...@@ -694,7 +695,7 @@ def test_init_kv_cache_without_kv_sharing():
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid(): def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
...@@ -1047,7 +1048,7 @@ def test_input_batch_with_kernel_block_sizes(): ...@@ -1047,7 +1048,7 @@ def test_input_batch_with_kernel_block_sizes():
assert block_table.block_size == kernel_size assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init): def test_hybrid_cache_integration(default_vllm_config, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner.""" """Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration # Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
......
...@@ -6,14 +6,14 @@ import torch ...@@ -6,14 +6,14 @@ import torch
from vllm.v1.worker.utils import bind_kv_cache from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache(): def test_bind_kv_cache(default_vllm_config):
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
ctx = { ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1), "layers.0.self_attn": Attention(32, 128, 0.1, prefix="layers.0.self_attn"),
"layers.1.self_attn": Attention(32, 128, 0.1), "layers.1.self_attn": Attention(32, 128, 0.1, prefix="layers.1.self_attn"),
"layers.2.self_attn": Attention(32, 128, 0.1), "layers.2.self_attn": Attention(32, 128, 0.1, prefix="layers.2.self_attn"),
"layers.3.self_attn": Attention(32, 128, 0.1), "layers.3.self_attn": Attention(32, 128, 0.1, prefix="layers.3.self_attn"),
} }
kv_cache = { kv_cache = {
"layers.0.self_attn": torch.zeros((1,)), "layers.0.self_attn": torch.zeros((1,)),
...@@ -34,13 +34,13 @@ def test_bind_kv_cache(): ...@@ -34,13 +34,13 @@ def test_bind_kv_cache():
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"] assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_non_attention(default_vllm_config):
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
# example from Jamba PP=2 # example from Jamba PP=2
ctx = { ctx = {
"model.layers.20.attn": Attention(32, 128, 0.1), "model.layers.20.attn": Attention(32, 128, 0.1, prefix="model.layers.20.attn"),
"model.layers.28.attn": Attention(32, 128, 0.1), "model.layers.28.attn": Attention(32, 128, 0.1, prefix="model.layers.28.attn"),
} }
kv_cache = { kv_cache = {
"model.layers.20.attn": torch.zeros((1,)), "model.layers.20.attn": torch.zeros((1,)),
......
...@@ -59,10 +59,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -59,10 +59,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
) )
# 2. override if passed by environment or config # 2. override if passed by environment or config
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config_or_none()
if vllm_config.attention_config.flash_attn_version is not None: if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations # 3. fallback for unsupported combinations
......
...@@ -42,6 +42,7 @@ from vllm.config.vllm import ( ...@@ -42,6 +42,7 @@ from vllm.config.vllm import (
VllmConfig, VllmConfig,
get_cached_compilation_config, get_cached_compilation_config,
get_current_vllm_config, get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config, get_layers_from_vllm_config,
set_current_vllm_config, set_current_vllm_config,
) )
...@@ -105,6 +106,7 @@ __all__ = [ ...@@ -105,6 +106,7 @@ __all__ = [
"VllmConfig", "VllmConfig",
"get_cached_compilation_config", "get_cached_compilation_config",
"get_current_vllm_config", "get_current_vllm_config",
"get_current_vllm_config_or_none",
"set_current_vllm_config", "set_current_vllm_config",
"get_layers_from_vllm_config", "get_layers_from_vllm_config",
] ]
...@@ -1441,13 +1441,18 @@ def get_cached_compilation_config(): ...@@ -1441,13 +1441,18 @@ def get_cached_compilation_config():
def get_current_vllm_config() -> VllmConfig: def get_current_vllm_config() -> VllmConfig:
if _current_vllm_config is None: if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly, raise AssertionError(
# we don't set the vllm config. In that case, we set a default "Current vLLM config is not set. This typically means "
# config. "get_current_vllm_config() was called outside of a "
# Use stack level 2 so the log contains the line of the caller, "set_current_vllm_config() context, or a CustomOp was instantiated "
# so it's easier to track down the source of the warning. "at module import time or model forward time when config is not set. "
logger.warning("Current vLLM config is not set.", stacklevel=2) "For tests that directly test custom ops/modules, use the "
return VllmConfig() "'default_vllm_config' pytest fixture from tests/conftest.py."
)
return _current_vllm_config
def get_current_vllm_config_or_none() -> VllmConfig | None:
return _current_vllm_config return _current_vllm_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment