Unverified commit 32e0c0bf, authored by wliao2, committed by GitHub
Browse files

refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)


Signed-off-by: Liao, Wei <wei.liao@intel.com>
parent 4a06e124
...@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( ...@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
set_random_seed(seed) set_random_seed(seed)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
torch.accelerator.set_device_index(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
......
...@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif( ...@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif(
reason="Backend not supported", reason="Backend not supported",
) )
DEVICE_TYPE = current_platform.device_type
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)] [
f"{DEVICE_TYPE}:{i}"
for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
if current_platform.is_cuda_alike() if current_platform.is_cuda_alike()
else ["cpu"] else ["cpu"]
) )
...@@ -196,7 +200,7 @@ def create_random_inputs( ...@@ -196,7 +200,7 @@ def create_random_inputs(
input_size: tuple[int, ...], input_size: tuple[int, ...],
input_range: tuple[float, float], input_range: tuple[float, float],
input_type: torch.dtype = torch.int, input_type: torch.dtype = torch.int,
device: torch.device = "cuda", device: torch.device = DEVICE_TYPE,
) -> tuple[list[torch.Tensor], list[int], list[int]]: ) -> tuple[list[torch.Tensor], list[int], list[int]]:
"""Creates random inputs. """Creates random inputs.
......
...@@ -35,9 +35,9 @@ EMBEDDING_MODULES = { ...@@ -35,9 +35,9 @@ EMBEDDING_MODULES = {
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
} }
DEVICE_TYPE = current_platform.device_type
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)] [f"{DEVICE_TYPE}:{i}" for i in range(min(torch.accelerator.device_count(), 2))]
if current_platform.is_cuda_alike() if current_platform.is_cuda_alike()
else ["cpu"] else ["cpu"]
) )
......
...@@ -6,6 +6,9 @@ import pytest ...@@ -6,6 +6,9 @@ import pytest
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
def round_up(x, base): def round_up(x, base):
...@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num): ...@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids[i, j] = pool[j] topk_ids[i, j] = pool[j]
token_lora_mapping[i] = random.randint(0, max_loras - 1) token_lora_mapping[i] = random.randint(0, max_loras - 1)
return topk_ids.to("cuda"), token_lora_mapping.to("cuda") return topk_ids.to(DEVICE_TYPE), token_lora_mapping.to(DEVICE_TYPE)
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 @pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
...@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size( ...@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size(
(max_loras * max_num_tokens_padded,), (max_loras * max_num_tokens_padded,),
topk_ids.numel(), topk_ids.numel(),
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=DEVICE_TYPE,
) )
expert_ids = torch.full( expert_ids = torch.full(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda" (max_loras * max_num_m_blocks,),
num_experts,
dtype=torch.int32,
device=DEVICE_TYPE,
)
num_tokens_post_pad = torch.zeros(
(max_loras,), dtype=torch.int32, device=DEVICE_TYPE
)
adapter_enabled = torch.ones(
(max_loras + 1,), dtype=torch.int32, device=DEVICE_TYPE
) )
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device=DEVICE_TYPE)
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
# call kernel # call kernel
ops.moe_lora_align_block_size( ops.moe_lora_align_block_size(
......
...@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops ...@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from .utils import PunicaTensors, assert_close, generate_data_for_nslices from .utils import PunicaTensors, assert_close, generate_data_for_nslices
DEVICE_TYPE = current_platform.device_type
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_device(reset_default_device): def reset_device(reset_default_device):
...@@ -146,7 +149,9 @@ def check_lora_shrink_kernel( ...@@ -146,7 +149,9 @@ def check_lora_shrink_kernel(
# Setup metadata information for the LoRA kernel. # Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make( lora_meta = LoRAKernelMeta.make(
max_loras=num_loras, max_num_tokens=token_nums, device="cuda" max_loras=num_loras,
max_num_tokens=token_nums,
device=DEVICE_TYPE,
) )
lora_meta.prepare_tensors(data.token_lora_mapping) lora_meta.prepare_tensors(data.token_lora_mapping)
...@@ -219,7 +224,9 @@ def check_lora_expand_kernel( ...@@ -219,7 +224,9 @@ def check_lora_expand_kernel(
# Setup metadata information for the LoRA kernel. # Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make( lora_meta = LoRAKernelMeta.make(
max_loras=num_loras, max_num_tokens=token_nums, device="cuda" max_loras=num_loras,
max_num_tokens=token_nums,
device=DEVICE_TYPE,
) )
lora_meta.prepare_tensors(data.token_lora_mapping) lora_meta.prepare_tensors(data.token_lora_mapping)
...@@ -367,7 +374,7 @@ test_params = { ...@@ -367,7 +374,7 @@ test_params = {
} }
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"] DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [0] SEED = [0]
......
...@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import ( ...@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
_SHRINK_LORA_SCALE_PTR_DICT, _SHRINK_LORA_SCALE_PTR_DICT,
) )
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
DEVICES = [f"cuda:{0}"] DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [0] SEED = [0]
_dict_lock = Lock() _dict_lock = Lock()
......
...@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig ...@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.lora.model_manager import LoRAMapping from vllm.lora.model_manager import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.v1.worker.gpu_worker import Worker from vllm.v1.worker.gpu_worker import Worker
MODEL_PATH = "Qwen/Qwen3-0.6B" MODEL_PATH = "Qwen/Qwen3-0.6B"
NUM_LORAS = 16 NUM_LORAS = 16
DEVICE_TYPE = current_platform.device_type
@patch.dict(os.environ, {"RANK": "0"}) @patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(qwen3_lora_files): def test_worker_apply_lora(qwen3_lora_files):
...@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files): ...@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files):
max_num_seqs=32, max_num_seqs=32,
max_num_partial_prefills=32, max_num_partial_prefills=32,
), ),
device_config=DeviceConfig("cuda"), device_config=DeviceConfig(DEVICE_TYPE),
cache_config=CacheConfig( cache_config=CacheConfig(
block_size=16, block_size=16,
cache_dtype="auto", cache_dtype="auto",
......
...@@ -9,10 +9,13 @@ import torch ...@@ -9,10 +9,13 @@ import torch
from safetensors.torch import save_file from safetensors.torch import save_file
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
class DummyLoRAManager: class DummyLoRAManager:
def __init__(self, device: torch.device = "cuda:0"): def __init__(self, device: torch.device = f"{DEVICE_TYPE}:0"):
super().__init__() super().__init__()
self._loras: dict[str, LoRALayerWeights] = {} self._loras: dict[str, LoRALayerWeights] = {}
self._device = device self._device = device
...@@ -57,8 +60,8 @@ class DummyLoRAManager: ...@@ -57,8 +60,8 @@ class DummyLoRAManager:
module_name, module_name,
rank=rank, rank=rank,
lora_alpha=1, lora_alpha=1,
lora_a=torch.rand([rank, input_dim], device="cuda"), lora_a=torch.rand([rank, input_dim], device=DEVICE_TYPE),
lora_b=torch.rand([output_dim, input_dim], device="cuda"), lora_b=torch.rand([output_dim, input_dim], device=DEVICE_TYPE),
embeddings_tensor=embeddings_tensor, embeddings_tensor=embeddings_tensor,
) )
self.set_module_lora(module_name, lora) self.set_module_lora(module_name, lora)
......
...@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [ ...@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
DEVICE_TYPE = current_platform.device_type
# Remove flashinfer from the list if it's not available # Remove flashinfer from the list if it's not available
try: try:
import flashinfer # noqa: F401 import flashinfer # noqa: F401
...@@ -366,7 +368,7 @@ def _test_backend_correctness( ...@@ -366,7 +368,7 @@ def _test_backend_correctness(
num_gpu_blocks=8192, num_gpu_blocks=8192,
hf_config_override=hf_config_override, hf_config_override=hf_config_override,
) )
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
......
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import torch import torch
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches
...@@ -22,6 +23,8 @@ class LocalAttentionTestData: ...@@ -22,6 +23,8 @@ class LocalAttentionTestData:
expected_local_block_table: list[list[int]] expected_local_block_table: list[list[int]]
DEVICE_TYPE = current_platform.device_type
test_data_list = [ test_data_list = [
# Same as example in docstring of make_local_attention_virtual_batches # Same as example in docstring of make_local_attention_virtual_batches
# except block table has 9 columns instead of 10 # except block table has 9 columns instead of 10
...@@ -151,7 +154,7 @@ test_data_list = [ ...@@ -151,7 +154,7 @@ test_data_list = [
@pytest.mark.parametrize("test_data", test_data_list) @pytest.mark.parametrize("test_data", test_data_list)
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
batch_spec = test_data.batch_spec batch_spec = test_data.batch_spec
attn_chunk_size = test_data.attn_chunk_size attn_chunk_size = test_data.attn_chunk_size
block_size = test_data.block_size block_size = test_data.block_size
......
...@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [ ...@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [
AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.TRITON_MLA,
] ]
DEVICE_TYPE = current_platform.device_type
# Remove sm100 backends from the list if not using sm100 # Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
...@@ -763,7 +765,7 @@ def test_backend_correctness( ...@@ -763,7 +765,7 @@ def test_backend_correctness(
method="ngram", num_speculative_tokens=query_len - 1 method="ngram", num_speculative_tokens=query_len - 1
) )
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
# 1. Setup # 1. Setup
batch_size = batch_spec.batch_size batch_size = batch_spec.batch_size
......
...@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec( ...@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2 seq_lens=[256] * 2, query_lens=[256] * 2
) )
DEVICE_TYPE = current_platform.device_type
def _float_to_e8m0_truncate(f: float) -> float: def _float_to_e8m0_truncate(f: float) -> float:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion. """Simulate SM100's float -> e8m0 -> bf16 scale conversion.
...@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness( ...@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla" use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16 dtype = torch.bfloat16
# Model hyper-parameters (kept intentionally small for the unit test) # Model hyper-parameters (kept intentionally small for the unit test)
...@@ -586,7 +588,7 @@ def _triton_convert_reference_impl( ...@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
def test_triton_convert_req_index_to_global_index_decode_only( def test_triton_convert_req_index_to_global_index_decode_only(
block_size, num_topk_tokens block_size, num_topk_tokens
): ):
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
num_tokens = 8 num_tokens = 8
num_requests = 4 num_requests = 4
max_blocks_per_req = 10 max_blocks_per_req = 10
...@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only( ...@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
reason="FlashMLASparseBackend requires CUDA 9.0 or higher", reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
) )
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size): def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
num_requests = 4 num_requests = 4
max_blocks_per_req = 8 max_blocks_per_req = 8
num_topk_tokens = 128 num_topk_tokens = 128
...@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow(): ...@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
def test_triton_convert_returns_valid_counts(): def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices.""" """Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
num_tokens = 8 num_tokens = 8
num_requests = 2 num_requests = 2
max_blocks_per_req = 10 max_blocks_per_req = 10
......
...@@ -55,6 +55,7 @@ class MockAttentionLayer: ...@@ -55,6 +55,7 @@ class MockAttentionLayer:
MODEL = "Qwen/Qwen2.5-0.5B" MODEL = "Qwen/Qwen2.5-0.5B"
BLOCK_SIZE = 16 BLOCK_SIZE = 16
NUM_GPU_BLOCKS = 8192 NUM_GPU_BLOCKS = 8192
DEVICE_TYPE = current_platform.device_type
BATCH_SPECS = { BATCH_SPECS = {
"decode_only": BatchSpec( "decode_only": BatchSpec(
...@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec): ...@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec):
"""Run TRTLLM attention through the full FlashInfer pipeline """Run TRTLLM attention through the full FlashInfer pipeline
and compare against an SDPA reference.""" and compare against an SDPA reference."""
set_random_seed(42) set_random_seed(42)
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
vllm_config = create_vllm_config( vllm_config = create_vllm_config(
model_name=MODEL, model_name=MODEL,
......
...@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context ...@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
DEVICE_TYPE = current_platform.device_type
# Helper MLP for testing # Helper MLP for testing
class SimpleMLP(nn.Module): class SimpleMLP(nn.Module):
...@@ -269,9 +271,9 @@ class TestCudagraphDispatcher: ...@@ -269,9 +271,9 @@ class TestCudagraphDispatcher:
class TestCUDAGraphWrapper: class TestCUDAGraphWrapper:
def setup_method(self): def setup_method(self):
self.vllm_config = _create_vllm_config(CompilationConfig()) self.vllm_config = _create_vllm_config(CompilationConfig())
self.model = SimpleMLP().to("cuda") self.model = SimpleMLP().to(DEVICE_TYPE)
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda") self.persistent_input_buffer = torch.zeros(1, 10, device=DEVICE_TYPE)
self.input_tensor = torch.randn(1, 10, device="cuda") self.input_tensor = torch.randn(1, 10, device=DEVICE_TYPE)
def test_capture_and_replay(self): def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper( wrapper = CUDAGraphWrapper(
...@@ -428,10 +430,10 @@ class TestCudagraphIntegration: ...@@ -428,10 +430,10 @@ class TestCudagraphIntegration:
@create_new_process_for_each_test("spawn") @create_new_process_for_each_test("spawn")
def test_capture_replay_bypass_logic(self): def test_capture_replay_bypass_logic(self):
model = SimpleMLP().to("cuda") model = SimpleMLP().to(DEVICE_TYPE)
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
max_bs = 16 max_bs = 16
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda") persistent_input_buffer = torch.zeros(max_bs, 10, device=DEVICE_TYPE)
input_1 = persistent_input_buffer[:1] input_1 = persistent_input_buffer[:1]
input_2 = persistent_input_buffer[:2] input_2 = persistent_input_buffer[:2]
input_3 = persistent_input_buffer[:3] input_3 = persistent_input_buffer[:3]
...@@ -486,17 +488,17 @@ class TestCudagraphIntegration: ...@@ -486,17 +488,17 @@ class TestCudagraphIntegration:
@create_new_process_for_each_test("spawn") @create_new_process_for_each_test("spawn")
def test_nested_wrappers(self): def test_nested_wrappers(self):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one.""" """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model = SimpleMLP().to("cuda") model = SimpleMLP().to(DEVICE_TYPE)
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
input_1 = torch.randn(1, 10, device="cuda") input_1 = torch.randn(1, 10, device=DEVICE_TYPE)
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model = SimpleMLP().to("cuda") inner_model = SimpleMLP().to(DEVICE_TYPE)
piecewise_wrapper = CUDAGraphWrapper( piecewise_wrapper = CUDAGraphWrapper(
inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
) )
inner_model.forward = MagicMock(wraps=inner_model.forward) inner_model.forward = MagicMock(wraps=inner_model.forward)
outer_model = SimpleMLP().to("cuda") outer_model = SimpleMLP().to(DEVICE_TYPE)
# When outer model is called, it calls the piecewise_wrapper # When outer model is called, it calls the piecewise_wrapper
outer_model.forward = MagicMock( outer_model.forward = MagicMock(
wraps=outer_model.forward, side_effect=piecewise_wrapper wraps=outer_model.forward, side_effect=piecewise_wrapper
......
...@@ -13,6 +13,9 @@ from utils import skip_unsupported ...@@ -13,6 +13,9 @@ from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@skip_unsupported @skip_unsupported
...@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard( ...@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
equivalent results to the standard CUDA implementation across various equivalent results to the standard CUDA implementation across various
configurations. configurations.
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
# Create test input and weight # Create test input and weight
torch.manual_seed(42) torch.manual_seed(42)
...@@ -81,7 +84,7 @@ def test_rms_norm_3d_input( ...@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models. inputs that are common in transformer models.
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16 dtype = torch.bfloat16
eps = 1e-6 eps = 1e-6
...@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config): ...@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
Ensures that both implementations handle edge cases like very small or large Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf. values without producing NaN or Inf.
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.float16 dtype = torch.float16
eps = 1e-6 eps = 1e-6
hidden_size = 2048 hidden_size = 2048
...@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config): ...@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
Verifies: output = input / sqrt(mean(input^2) + eps) * weight Verifies: output = input / sqrt(mean(input^2) + eps) * weight
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.float32 # Use float32 for higher precision in formula check dtype = torch.float32 # Use float32 for higher precision in formula check
eps = 1e-6 eps = 1e-6
hidden_size = 1024 hidden_size = 1024
...@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int): ...@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size. correctly handles hidden sizes both smaller and larger than the block size.
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16 dtype = torch.bfloat16
eps = 1e-6 eps = 1e-6
batch_size = 16 batch_size = 16
...@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config): ...@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
Runs the same input through the kernel multiple times and verifies Runs the same input through the kernel multiple times and verifies
identical outputs. identical outputs.
""" """
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16 dtype = torch.bfloat16
eps = 1e-6 eps = 1e-6
hidden_size = 4096 hidden_size = 4096
...@@ -283,7 +286,7 @@ if __name__ == "__main__": ...@@ -283,7 +286,7 @@ if __name__ == "__main__":
# Run a quick smoke test # Run a quick smoke test
print("Running quick smoke test of RMS norm implementations...") print("Running quick smoke test of RMS norm implementations...")
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
batch_size = 8 batch_size = 8
hidden_size = 4096 hidden_size = 4096
dtype = torch.bfloat16 dtype = torch.bfloat16
......
...@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt ...@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
...@@ -48,6 +49,7 @@ num_accepted_tokens = 1 ...@@ -48,6 +49,7 @@ num_accepted_tokens = 1
prompt_token_ids: list[int] = [] prompt_token_ids: list[int] = []
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE = 560 BLOCK_SIZE = 560
DEVICE_TYPE = current_platform.device_type
NUM_HIDDEN_LAYERS = 1 NUM_HIDDEN_LAYERS = 1
cur_step_action_idx = 0 cur_step_action_idx = 0
cur_step_action: StepAction | None = None cur_step_action: StepAction | None = None
...@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput: ...@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
return SamplerOutput( return SamplerOutput(
sampled_token_ids=torch.tensor( sampled_token_ids=torch.tensor(
[[prompt_token_ids[first_token_id_index]]], [[prompt_token_ids[first_token_id_index]]],
device="cuda", device=DEVICE_TYPE,
dtype=torch.int32, dtype=torch.int32,
), ),
logprobs_tensors=None, logprobs_tensors=None,
...@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput: ...@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
sampled_token_ids = accepted_tokens sampled_token_ids = accepted_tokens
return SamplerOutput( return SamplerOutput(
sampled_token_ids=torch.tensor( sampled_token_ids=torch.tensor(
[sampled_token_ids], device="cuda", dtype=torch.int32 [sampled_token_ids],
device=DEVICE_TYPE,
dtype=torch.int32,
), ),
logprobs_tensors=None, logprobs_tensors=None,
) )
...@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn(): ...@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
- 1 - 1
+ num_accepted_tokens + num_accepted_tokens
], ],
device="cuda", device=DEVICE_TYPE,
dtype=torch.int32, dtype=torch.int32,
) )
valid_sampled_tokens_count = torch.tensor( valid_sampled_tokens_count = torch.tensor(
[num_accepted_tokens], device="cuda", dtype=torch.int32 [num_accepted_tokens],
device=DEVICE_TYPE,
dtype=torch.int32,
) )
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32) return torch.tensor(
proposed_draft_token_ids,
device=DEVICE_TYPE,
dtype=torch.int32,
)
return fake_propose_draft_token_ids_fn return fake_propose_draft_token_ids_fn
......
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
import pytest import pytest
import torch import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import ( from vllm.v1.kv_offload.spec import (
...@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024] ...@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024]
BLOCK_SIZE_FACTORS = [1, 3] BLOCK_SIZE_FACTORS = [1, 3]
NUM_TENSORS = [4] NUM_TENSORS = [4]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = ["cuda:0"] DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:0"]
NUM_MAPPINGS = [3] NUM_MAPPINGS = [3]
...@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3] ...@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS) @pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_tensors", NUM_TENSORS) @pytest.mark.parametrize("num_tensors", NUM_TENSORS)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_transfer( def test_transfer(
default_vllm_config, default_vllm_config,
......
...@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available() ...@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256 MAX_NUM_REQS = 256
VOCAB_SIZE = 1024 VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20 NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [ DEVICE_TYPE = current_platform.device_type
f"{current_platform.device_type}:{i}" DEVICES = [
f"{DEVICE_TYPE}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2) for i in range(1 if current_platform.device_count() == 1 else 2)
] ]
MAX_NUM_PROMPT_TOKENS = 64 MAX_NUM_PROMPT_TOKENS = 64
...@@ -801,7 +802,7 @@ def _assert_valid( ...@@ -801,7 +802,7 @@ def _assert_valid(
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs( def test_logitsprocs(
......
...@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import ( ...@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
from vllm.v1.sample.sampler import Sampler, SamplerOutput from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type DEVICE_TYPE = current_platform.device_type
@pytest.fixture @pytest.fixture
...@@ -57,7 +57,7 @@ def create_logits_tensor( ...@@ -57,7 +57,7 @@ def create_logits_tensor(
will produce desired token ids on argmax""" will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids] token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids) num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE_TYPE)
start_loc = 0 start_loc = 0
for tokens in token_ids: for tokens in token_ids:
for j, token_id in enumerate(tokens): for j, token_id in enumerate(tokens):
...@@ -99,9 +99,9 @@ def create_sampling_metadata( ...@@ -99,9 +99,9 @@ def create_sampling_metadata(
assert output_token_ids assert output_token_ids
assert len(output_token_ids) > 0 assert len(output_token_ids) > 0
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE) frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE_TYPE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE) presence_penalties = torch.tensor(presence_penalties, device=DEVICE_TYPE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE) repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE_TYPE)
else: else:
no_penalties = True no_penalties = True
frequency_penalties = torch.tensor([]) frequency_penalties = torch.tensor([])
...@@ -320,14 +320,27 @@ def test_deterministic_when_seeded( ...@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
n_rep: int, n_rep: int,
): ):
num_tokens = batch_size * k num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE) draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1) draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs) target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint( bonus_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
) )
draft_token_ids = torch.randint( draft_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device=DEVICE_TYPE,
) )
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
...@@ -335,12 +348,12 @@ def test_deterministic_when_seeded( ...@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
results = [] results = []
for _ in range(n_rep): for _ in range(n_rep):
seeded_seqs = { seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i) i: torch.Generator(device=DEVICE_TYPE).manual_seed(i)
for i in range(batch_size) for i in range(batch_size)
if seeded_mask[i] if seeded_mask[i]
} }
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs all_greedy=False, temperature=temperature, generators=seeded_seqs
) )
...@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
much more than the distance improvement between the observed much more than the distance improvement between the observed
distribution and the random distribution. distribution and the random distribution.
""" """
torch.set_default_device(DEVICE) torch.set_default_device(DEVICE_TYPE)
vocab_size = 10 vocab_size = 10
k = 2 k = 2
num_reference_probs = 100 num_reference_probs = 100
...@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
rej_sample_probs = estimate_rejection_sampling_pdf( rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_logits, k, vocab_size, num_samples draft_probs, target_logits, k, vocab_size, num_samples
) )
rej_sample_probs = rej_sample_probs.to(DEVICE) rej_sample_probs = rej_sample_probs.to(DEVICE_TYPE)
# Average distance from reference probs. # Average distance from reference probs.
reference_vs_rejsample_dist = ( reference_vs_rejsample_dist = (
...@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf( ...@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
draft_probs = draft_probs.view(num_tokens, vocab_size) draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required. # Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat( bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE_TYPE).repeat(
num_samples, 1 num_samples, 1
) )
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature all_greedy=False, temperature=temperature
) )
...@@ -600,7 +613,7 @@ def _test_masked_logits( ...@@ -600,7 +613,7 @@ def _test_masked_logits(
# Create random draft probabilities. # Create random draft probabilities.
draft_probs = torch.rand( draft_probs = torch.rand(
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE (num_tokens, vocab_size), dtype=torch.float32, device=DEVICE_TYPE
) )
draft_probs = F.softmax(draft_probs, dim=-1) draft_probs = F.softmax(draft_probs, dim=-1)
...@@ -610,7 +623,11 @@ def _test_masked_logits( ...@@ -610,7 +623,11 @@ def _test_masked_logits(
draft_token_ids = draft_token_ids.tolist() draft_token_ids = draft_token_ids.tolist()
# Bonus tokens not used but required # Bonus tokens not used but required
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) bonus_token_ids = torch.zeros(
(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
)
# Create spec decode metadata # Create spec decode metadata
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits) spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
...@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k): ...@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices. # Randomly create top-k indices.
top_k_indices = [ top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens) torch.randperm(vocab_size, device=DEVICE_TYPE)[:top_k]
for _ in range(num_tokens)
] ]
top_k_indices = torch.stack(top_k_indices) top_k_indices = torch.stack(top_k_indices)
# Create logits with the uniform distribution. # Create logits with the uniform distribution.
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE) target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE_TYPE)
# Increment the logits for top-k indices, a little bit more than the other # Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be # ones. If the masking is effective, the non-topk indices will never be
...@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k): ...@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
target_logits[i, top_k_indices[i]] += 0.1 target_logits[i, top_k_indices[i]] += 0.1
# Create sampling metadata # Create sampling metadata
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, all_greedy=False,
temperature=temperature, temperature=temperature,
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64), top_k=torch.tensor([top_k] * batch_size, device=DEVICE_TYPE, dtype=torch.int64),
) )
_test_masked_logits( _test_masked_logits(
...@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p): ...@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
num_tokens = batch_size * num_draft_tokens num_tokens = batch_size * num_draft_tokens
# Create logits with the uniform distribution. # Create logits with the uniform distribution.
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE) target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE_TYPE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
rescaled_logits = target_logits / temperature rescaled_logits = target_logits / temperature
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False) logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
...@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p): ...@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, all_greedy=False,
temperature=temperature, temperature=temperature,
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32), top_p=torch.tensor(
[top_p] * batch_size,
device=DEVICE_TYPE,
dtype=torch.float32,
),
) )
_test_masked_logits( _test_masked_logits(
...@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler): ...@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
all_greedy=True, all_greedy=True,
output_token_ids=[[2], [3], [4]], output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens, spec_token_ids=spec_tokens,
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE), prompt_token_ids=torch.tensor(
[[5, 6, 7], [6, 7, 8], [7, 8, 9]],
device=DEVICE_TYPE,
),
frequency_penalties=[1.5, 1.5, 0.7], frequency_penalties=[1.5, 1.5, 0.7],
presence_penalties=[0.0] * num_requests, presence_penalties=[0.0] * num_requests,
repetition_penalties=[1.0] * num_requests, repetition_penalties=[1.0] * num_requests,
...@@ -858,21 +883,26 @@ def test_sample_recovered_tokens( ...@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
num_tokens = batch_size * max_spec_len num_tokens = batch_size * max_spec_len
# Create random draft probabilities. # Create random draft probabilities.
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE) draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1) draft_probs = F.softmax(draft_probs, dim=-1)
# Create random target probabilities. # Create random target probabilities.
target_logits = torch.rand( target_logits = torch.rand(
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE
) )
target_probs = F.softmax(target_logits, dim=-1) target_probs = F.softmax(target_logits, dim=-1)
# Randomly sample draft token ids from draft probs # Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32) draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
generators = { generators = {
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size) i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
} }
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=generators all_greedy=False, temperature=temperature, generators=generators
...@@ -890,7 +920,7 @@ def test_sample_recovered_tokens( ...@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs, None if no_draft_probs else draft_probs,
target_probs, target_probs,
sampling_metadata, sampling_metadata,
device=DEVICE, device=DEVICE_TYPE,
) )
recovered_token_ids = sample_recovered_tokens( recovered_token_ids = sample_recovered_tokens(
max_spec_len, max_spec_len,
...@@ -900,6 +930,6 @@ def test_sample_recovered_tokens( ...@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs, None if no_draft_probs else draft_probs,
target_probs, target_probs,
sampling_metadata, sampling_metadata,
device=DEVICE, device=DEVICE_TYPE,
) )
assert torch.equal(recovered_token_ids, ref_recovered_token_ids) assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
...@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available() ...@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256 MAX_NUM_REQS = 256
VOCAB_SIZE = 1024 VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20 NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [ DEVICE_TYPE = current_platform.device_type
f"{current_platform.device_type}:{i}" DEVICES = [
f"{DEVICE_TYPE}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2) for i in range(1 if current_platform.device_count() == 1 else 2)
] ]
MAX_NUM_PROMPT_TOKENS = 64 MAX_NUM_PROMPT_TOKENS = 64
...@@ -199,7 +200,7 @@ def _create_weighted_output_token_list( ...@@ -199,7 +200,7 @@ def _create_weighted_output_token_list(
return output_token_ids, sorted_token_ids_in_output return output_token_ids, sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty( def test_sampler_presence_penalty(
...@@ -249,7 +250,7 @@ def test_sampler_presence_penalty( ...@@ -249,7 +250,7 @@ def test_sampler_presence_penalty(
assert penalized_token_id not in output_token_ids[batch_idx] assert penalized_token_id not in output_token_ids[batch_idx]
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) @pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty( def test_sampler_frequency_penalty(
...@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty( ...@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty(
assert penalized_token_id not in distinct_sorted_token_ids_in_output assert penalized_token_id not in distinct_sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) @pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty( def test_sampler_repetition_penalty(
...@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty( ...@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty(
) )
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids( def test_sampler_allowed_token_ids(
...@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids( ...@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids(
assert logits_for_req[token_id] != -float("inf") assert logits_for_req[token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)]) @pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)])
def test_sampler_bad_words( def test_sampler_bad_words(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment