Unverified commit 32e0c0bf, authored by wliao2 and committed by GitHub
Browse files

refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)


Signed-off-by: Liao, Wei <wei.liao@intel.com>
parent 4a06e124
...@@ -7,8 +7,7 @@ from torch import Generator ...@@ -7,8 +7,7 @@ from torch import Generator
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None DEVICE_TYPE = current_platform.device_type
DEVICE = current_platform.device_type
BATCH_SIZE = 1024 BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024 VOCAB_SIZE = 128 * 1024
...@@ -26,8 +25,8 @@ def reset_default_device(): ...@@ -26,8 +25,8 @@ def reset_default_device():
def test_topk_impl_equivalence(): def test_topk_impl_equivalence():
torch.set_default_device(DEVICE) torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE).manual_seed(33) generator = Generator(device=DEVICE_TYPE).manual_seed(33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
...@@ -76,8 +75,8 @@ def test_flashinfer_sampler(): ...@@ -76,8 +75,8 @@ def test_flashinfer_sampler():
if not FLASHINFER_ENABLED: if not FLASHINFER_ENABLED:
pytest.skip("FlashInfer not installed or not available on this platform.") pytest.skip("FlashInfer not installed or not available on this platform.")
torch.set_default_device(DEVICE) torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE).manual_seed(42) generator = Generator(device=DEVICE_TYPE).manual_seed(42)
# Generate random logits # Generate random logits
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
...@@ -128,15 +127,15 @@ def test_flashinfer_sampler(): ...@@ -128,15 +127,15 @@ def test_flashinfer_sampler():
# ============================================================================= # =============================================================================
@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available") @pytest.mark.skipif("CPU" in DEVICE_TYPE, reason="CUDA/XPU not available")
class TestTritonTopkTopp: class TestTritonTopkTopp:
"""Tests for the Triton top-k/top-p kernel.""" """Tests for the Triton top-k/top-p kernel."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self): def setup(self):
"""Set up test fixtures.""" """Set up test fixtures."""
torch.set_default_device(CUDA_DEVICE) torch.set_default_device(DEVICE_TYPE)
self.generator = Generator(device=CUDA_DEVICE).manual_seed(42) self.generator = Generator(device=DEVICE_TYPE).manual_seed(42)
def _compare_results( def _compare_results(
self, self,
......
...@@ -42,6 +42,7 @@ dflash_target_dir = "Qwen/Qwen3-8B" ...@@ -42,6 +42,7 @@ dflash_target_dir = "Qwen/Qwen3-8B"
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16" dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE = 16 BLOCK_SIZE = 16
DEVICE_TYPE = current_platform.device_type
def _create_proposer( def _create_proposer(
...@@ -92,7 +93,7 @@ def _create_proposer( ...@@ -92,7 +93,7 @@ def _create_proposer(
# Overwrite pard_token to avoid crash during init # Overwrite pard_token to avoid crash during init
speculative_config.draft_model_config.hf_config.pard_token = 0 speculative_config.draft_model_config.hf_config.pard_token = 0
device = current_platform.device_type device = DEVICE_TYPE
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(block_size=16), cache_config=CacheConfig(block_size=16),
...@@ -124,7 +125,7 @@ def test_prepare_next_token_ids(): ...@@ -124,7 +125,7 @@ def test_prepare_next_token_ids():
either the GPU tensor of sampled_token_ids with -1 for rejected tokens, either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
or the CPU python list[list[int]] with the rejected tokens removed. or the CPU python list[list[int]] with the rejected tokens removed.
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
num_requests = 4 num_requests = 4
num_speculative_tokens = 4 num_speculative_tokens = 4
...@@ -207,7 +208,7 @@ def test_prepare_inputs(): ...@@ -207,7 +208,7 @@ def test_prepare_inputs():
a, a + 1, ..., a + b - n2 - 1, a, a + 1, ..., a + b - n2 - 1,
a + b, a + b + 1, ..., a + b + c - n3 - 1] a + b, a + b + 1, ..., a + b + c - n3 - 1]
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
# q1 = 4, q2 = 7, q3 = 5 # q1 = 4, q2 = 7, q3 = 5
# n1 = 1, n2 = 3, n3 = 2 # n1 = 1, n2 = 3, n3 = 2
...@@ -300,7 +301,7 @@ def test_prepare_inputs_padded(): ...@@ -300,7 +301,7 @@ def test_prepare_inputs_padded():
from the original indices to sample from. from the original indices to sample from.
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
expected_token_indices_to_sample = torch.tensor( expected_token_indices_to_sample = torch.tensor(
[1, 5, 6], dtype=torch.int32, device=device [1, 5, 6], dtype=torch.int32, device=device
...@@ -370,7 +371,7 @@ def test_set_inputs_first_pass_default_eagle(): ...@@ -370,7 +371,7 @@ def test_set_inputs_first_pass_default_eagle():
- After inserting next_tokens [100, 200, 300]: - After inserting next_tokens [100, 200, 300]:
[a2, a3, 100, b2, 200, c2, c3, c4, 300] [a2, a3, 100, b2, 200, c2, c3, c4, 300]
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
num_speculative_tokens = 3 num_speculative_tokens = 3
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer("eagle", num_speculative_tokens)
...@@ -471,7 +472,7 @@ def test_set_inputs_first_pass_draft_model(): ...@@ -471,7 +472,7 @@ def test_set_inputs_first_pass_draft_model():
- idx 5: token 21, pos 1 - idx 5: token 21, pos 1
- idx 6: token 200, pos 2 (bonus token) - idx 6: token 200, pos 2 (bonus token)
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
num_speculative_tokens = 2 num_speculative_tokens = 2
block_size = BLOCK_SIZE block_size = BLOCK_SIZE
...@@ -609,7 +610,7 @@ def test_set_inputs_first_pass_parallel_drafting(): ...@@ -609,7 +610,7 @@ def test_set_inputs_first_pass_parallel_drafting():
- idx 9: bonus token 200 - idx 9: bonus token 200
- idx 10-11: parallel_drafting_tokens, is_masked=True - idx 10-11: parallel_drafting_tokens, is_masked=True
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
num_speculative_tokens = 3 num_speculative_tokens = 3
block_size = BLOCK_SIZE block_size = BLOCK_SIZE
...@@ -859,7 +860,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -859,7 +860,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Use GPU device # Use GPU device
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
# Setup test parameters # Setup test parameters
batch_size = 2 batch_size = 2
...@@ -1030,7 +1031,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -1030,7 +1031,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
) )
def test_propose_tree(spec_token_tree): def test_propose_tree(spec_token_tree):
# Get GPU device. # Get GPU device.
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
# Setup test parameters. # Setup test parameters.
batch_size = 2 batch_size = 2
......
...@@ -5,11 +5,14 @@ ...@@ -5,11 +5,14 @@
import pytest import pytest
import torch import torch
from vllm.platforms import current_platform
from vllm.v1.spec_decode.utils import ( from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID, PADDING_SLOT_ID,
eagle_step_update_slot_mapping_and_metadata, eagle_step_update_slot_mapping_and_metadata,
) )
DEVICE_TYPE = current_platform.device_type
# Skip if no CUDA - Triton kernel requires GPU # Skip if no CUDA - Triton kernel requires GPU
pytest.importorskip("triton") pytest.importorskip("triton")
if not torch.cuda.is_available(): if not torch.cuda.is_available():
...@@ -47,7 +50,7 @@ def _reference_eagle_step_slot_mapping( ...@@ -47,7 +50,7 @@ def _reference_eagle_step_slot_mapping(
def test_eagle_step_slot_mapping_kernel(): def test_eagle_step_slot_mapping_kernel():
"""Test fused kernel matches Python reference for slot mapping and metadata.""" """Test fused kernel matches Python reference for slot mapping and metadata."""
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
batch_size = 32 batch_size = 32
block_size = 16 block_size = 16
max_model_len = 4096 max_model_len = 4096
...@@ -93,7 +96,7 @@ def test_eagle_step_slot_mapping_kernel(): ...@@ -93,7 +96,7 @@ def test_eagle_step_slot_mapping_kernel():
def test_eagle_step_slot_mapping_kernel_exceeds_max(): def test_eagle_step_slot_mapping_kernel_exceeds_max():
"""Test fused kernel when position exceeds max_model_len.""" """Test fused kernel when position exceeds max_model_len."""
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
batch_size = 4 batch_size = 4
block_size = 16 block_size = 16
max_model_len = 100 max_model_len = 100
...@@ -130,7 +133,7 @@ def test_eagle_step_slot_mapping_kernel_exceeds_max(): ...@@ -130,7 +133,7 @@ def test_eagle_step_slot_mapping_kernel_exceeds_max():
def test_eagle_step_slot_mapping_kernel_cudagraph_padding(): def test_eagle_step_slot_mapping_kernel_cudagraph_padding():
"""Test that padding threads write PADDING_SLOT_ID when """Test that padding threads write PADDING_SLOT_ID when
input_batch_size > batch_size (cudagraph padding).""" input_batch_size > batch_size (cudagraph padding)."""
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
batch_size = 4 batch_size = 4
input_batch_size = 8 input_batch_size = 8
block_size = 16 block_size = 16
......
...@@ -27,6 +27,7 @@ from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesPropose ...@@ -27,6 +27,7 @@ from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesPropose
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEVICE_TYPE = current_platform.device_type
def _create_proposer( def _create_proposer(
...@@ -51,7 +52,7 @@ def _create_proposer( ...@@ -51,7 +52,7 @@ def _create_proposer(
}, },
) )
device = current_platform.device_type device = DEVICE_TYPE
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(), cache_config=CacheConfig(),
...@@ -101,7 +102,7 @@ def test_proposer_initialization_missing_layer_ids(): ...@@ -101,7 +102,7 @@ def test_proposer_initialization_missing_layer_ids():
}, },
) )
device = current_platform.device_type device = DEVICE_TYPE
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(), cache_config=CacheConfig(),
...@@ -130,7 +131,7 @@ def test_prepare_next_token_ids_padded(): ...@@ -130,7 +131,7 @@ def test_prepare_next_token_ids_padded():
For each request we either use the sampled token (if valid and not discarded) For each request we either use the sampled token (if valid and not discarded)
or a backup token from the request state. or a backup token from the request state.
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
num_requests = 4 num_requests = 4
req_ids = [f"req_{i + 1}" for i in range(num_requests)] req_ids = [f"req_{i + 1}" for i in range(num_requests)]
...@@ -197,7 +198,7 @@ def test_propose(): ...@@ -197,7 +198,7 @@ def test_propose():
2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1]) 2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1])
3. Cache the hidden states in the model's KV cache 3. Cache the hidden states in the model's KV cache
""" """
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
# Setup test parameters # Setup test parameters
batch_size = 2 batch_size = 2
...@@ -273,7 +274,7 @@ def test_propose(): ...@@ -273,7 +274,7 @@ def test_propose():
@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8]) @pytest.mark.parametrize("num_hidden_layers", [1, 4, 8])
def test_propose_different_layer_counts(num_hidden_layers): def test_propose_different_layer_counts(num_hidden_layers):
"""Test that propose works correctly with different numbers of hidden layers.""" """Test that propose works correctly with different numbers of hidden layers."""
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
batch_size = 2 batch_size = 2
num_tokens = 5 num_tokens = 5
......
...@@ -28,6 +28,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum ...@@ -28,6 +28,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base" mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
DEVICE_TYPE = current_platform.device_type
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
...@@ -48,7 +49,7 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: ...@@ -48,7 +49,7 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(), cache_config=CacheConfig(),
speculative_config=speculative_config, speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type), device_config=DeviceConfig(device=DEVICE_TYPE),
parallel_config=ParallelConfig(), parallel_config=ParallelConfig(),
load_config=LoadConfig(), load_config=LoadConfig(),
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
...@@ -57,7 +58,7 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: ...@@ -57,7 +58,7 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
), ),
) )
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) return EagleProposer(vllm_config=vllm_config, device=DEVICE_TYPE)
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
...@@ -118,7 +119,7 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro ...@@ -118,7 +119,7 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
def test_mtp_propose(num_speculative_tokens, monkeypatch): def test_mtp_propose(num_speculative_tokens, monkeypatch):
"""Test that MTP's forward method returns hidden states directly""" """Test that MTP's forward method returns hidden states directly"""
device = torch.device(current_platform.device_type) device = torch.device(DEVICE_TYPE)
batch_size = 2 batch_size = 2
seq_lens = [5, 3] seq_lens = [5, 3]
total_tokens = sum(seq_lens) total_tokens = sum(seq_lens)
......
...@@ -18,6 +18,8 @@ from vllm.v1.attention.backend import CommonAttentionMetadata ...@@ -18,6 +18,8 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
DEVICE_TYPE = current_platform.device_type
if not is_flash_attn_varlen_func_available(): if not is_flash_attn_varlen_func_available():
pytest.skip( pytest.skip(
"This test requires flash_attn_varlen_func, but it's not available.", "This test requires flash_attn_varlen_func, but it's not available.",
...@@ -170,9 +172,9 @@ def _get_available_reference_backends() -> list[AttentionBackendEnum]: ...@@ -170,9 +172,9 @@ def _get_available_reference_backends() -> list[AttentionBackendEnum]:
class MockAttentionLayer(torch.nn.Module): class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _k_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _v_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
layer_name = "mock_layer" layer_name = "mock_layer"
def __init__(self): def __init__(self):
......
...@@ -22,10 +22,8 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -22,10 +22,8 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024 VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20 NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100 MAX_PROMPT_SIZE = 100
CUDA_DEVICES = [ DEVICE_TYPE = current_platform.device_type
f"{current_platform.device_type}:{i}" DEVICES = [f"{DEVICE_TYPE}:{i}" for i in range(min(current_platform.device_count(), 2))]
for i in range(min(current_platform.device_count(), 2))
]
MAX_NUM_PROMPT_TOKENS = 64 MAX_NUM_PROMPT_TOKENS = 64
...@@ -219,7 +217,7 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -219,7 +217,7 @@ def _construct_cached_request_state(req_id_suffix: int):
) )
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) @pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int): def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
""" """
...@@ -313,7 +311,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -313,7 +311,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
) )
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1),)]) @pytest.mark.parametrize("swap_list", [((0, 1),)])
def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list): def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list):
...@@ -400,7 +398,7 @@ def _construct_pooling_request(req_id_suffix: int, pooling_params=None): ...@@ -400,7 +398,7 @@ def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
) )
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_pooling_prompt_lens_not_aliased(device: str): def test_pooling_prompt_lens_not_aliased(device: str):
"""Verify that prompt_lens in PoolingMetadata does not share memory """Verify that prompt_lens in PoolingMetadata does not share memory
with the internal num_prompt_tokens pinned buffer. Guards against possible with the internal num_prompt_tokens pinned buffer. Guards against possible
......
...@@ -45,7 +45,7 @@ from vllm.v1.worker.utils import AttentionGroup, select_common_block_size ...@@ -45,7 +45,7 @@ from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
BLOCK_SIZE = 16 BLOCK_SIZE = 16
NUM_BLOCKS = 10 NUM_BLOCKS = 10
DEVICE = current_platform.device_type DEVICE_TYPE = current_platform.device_type
def initialize_kv_cache(runner: GPUModelRunner): def initialize_kv_cache(runner: GPUModelRunner):
...@@ -121,7 +121,7 @@ def model_runner(): ...@@ -121,7 +121,7 @@ def model_runner():
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1 num_heads, head_size, 0.1
) )
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
initialize_kv_cache(runner) initialize_kv_cache(runner)
yield runner yield runner
...@@ -340,7 +340,7 @@ def test_get_nans_in_logits(model_runner, dist_init): ...@@ -340,7 +340,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
[3.0, 2.0, 1.0], [3.0, 2.0, 1.0],
], ],
device=DEVICE, device=DEVICE_TYPE,
) )
result = model_runner._get_nans_in_logits(logits) result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 0} assert result == {"req_0": 0, "req_1": 0}
...@@ -350,7 +350,7 @@ def test_get_nans_in_logits(model_runner, dist_init): ...@@ -350,7 +350,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
[1.0, float("nan"), 3.0], [1.0, float("nan"), 3.0],
[4.0, float("nan"), float("nan")], [4.0, float("nan"), float("nan")],
], ],
device=DEVICE, device=DEVICE_TYPE,
) )
result = model_runner._get_nans_in_logits(logits) result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 1, "req_1": 2} assert result == {"req_0": 1, "req_1": 2}
...@@ -360,7 +360,7 @@ def test_get_nans_in_logits(model_runner, dist_init): ...@@ -360,7 +360,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
[4.0, float("nan"), float("nan")], [4.0, float("nan"), float("nan")],
], ],
device=DEVICE, device=DEVICE_TYPE,
) )
result = model_runner._get_nans_in_logits(logits) result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 2} assert result == {"req_0": 0, "req_1": 2}
...@@ -372,7 +372,7 @@ def test_get_nans_in_logits(model_runner, dist_init): ...@@ -372,7 +372,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
[ [
[1.0, float("nan"), 3.0], [1.0, float("nan"), 3.0],
], ],
device=DEVICE, device=DEVICE_TYPE,
) )
result = model_runner._get_nans_in_logits(logits) result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 1, "req_1": 0} assert result == {"req_0": 1, "req_1": 0}
...@@ -383,7 +383,7 @@ def test_get_nans_in_logits(model_runner, dist_init): ...@@ -383,7 +383,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
[float("nan"), 2.0, 3.0], [float("nan"), 2.0, 3.0],
], ],
device=DEVICE, device=DEVICE_TYPE,
) )
result = model_runner._get_nans_in_logits(logits) result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 2, "req_1": 0} assert result == {"req_0": 2, "req_1": 0}
...@@ -643,7 +643,7 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config): ...@@ -643,7 +643,7 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
# Set high context length to test max context length estimation # Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000 vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
kv_cache_spec = runner.get_kv_cache_spec() kv_cache_spec = runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2 assert len(kv_cache_spec) == 2
assert len(runner.shared_kv_cache_layers) == 0 assert len(runner.shared_kv_cache_layers) == 0
...@@ -711,7 +711,7 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config): ...@@ -711,7 +711,7 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
# Set high context length to test max context length estimation # Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000 vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
kv_cache_spec = runner.get_kv_cache_spec() kv_cache_spec = runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1 assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec assert layer_0 in kv_cache_spec
...@@ -850,7 +850,7 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -850,7 +850,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
assert fwd_context is not None assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
current_platform.update_block_size_for_backend(vllm_config) current_platform.update_block_size_for_backend(vllm_config)
kv_cache_spec = runner.get_kv_cache_spec() kv_cache_spec = runner.get_kv_cache_spec()
...@@ -896,13 +896,13 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -896,13 +896,13 @@ def test_hybrid_attention_mamba_tensor_shapes():
ssm_constant_shape = ssm_shape[1:] ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full( attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 (test_block_size, *attn_constant_shape), device=DEVICE_TYPE, fill_value=3.33
) )
conv_blocks_constant = torch.full( conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 (test_block_size, *conv_constant_shape), device=DEVICE_TYPE, fill_value=6.66
) )
ssm_blocks_constant = torch.full( ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 (test_block_size, *ssm_constant_shape), device=DEVICE_TYPE, fill_value=9.99
) )
# Fill attention blocks with constants using kv block indices # Fill attention blocks with constants using kv block indices
...@@ -997,7 +997,7 @@ def test_hybrid_block_table_initialization(): ...@@ -997,7 +997,7 @@ def test_hybrid_block_table_initialization():
max_num_blocks_per_req=max_num_blocks_per_req, max_num_blocks_per_req=max_num_blocks_per_req,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
pin_memory=False, pin_memory=False,
device=torch.device(DEVICE), device=torch.device(DEVICE_TYPE),
kernel_block_size=kernel_block_sizes[0], kernel_block_size=kernel_block_sizes[0],
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
) )
...@@ -1036,7 +1036,7 @@ def test_input_batch_with_kernel_block_sizes(): ...@@ -1036,7 +1036,7 @@ def test_input_batch_with_kernel_block_sizes():
max_num_reqs = 10 max_num_reqs = 10
max_model_len = 512 max_model_len = 512
max_num_batched_tokens = 512 max_num_batched_tokens = 512
device = torch.device(DEVICE) device = torch.device(DEVICE_TYPE)
pin_memory = False pin_memory = False
vocab_size = 50272 vocab_size = 50272
...@@ -1083,7 +1083,7 @@ def test_hybrid_cache_integration(default_vllm_config, dist_init): ...@@ -1083,7 +1083,7 @@ def test_hybrid_cache_integration(default_vllm_config, dist_init):
num_heads, head_size, 0.1 num_heads, head_size, 0.1
) )
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
# Initialize KV cache with configuration # Initialize KV cache with configuration
attn_spec = FullAttentionSpec( attn_spec = FullAttentionSpec(
...@@ -1306,7 +1306,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): ...@@ -1306,7 +1306,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
) )
assert fwd_context is not None assert fwd_context is not None
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
current_platform.update_block_size_for_backend(vllm_config) current_platform.update_block_size_for_backend(vllm_config)
kv_cache_spec = runner.get_kv_cache_spec() kv_cache_spec = runner.get_kv_cache_spec()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment