Unverified Commit 32e0c0bf authored by wliao2's avatar wliao2 Committed by GitHub
Browse files

refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)


Signed-off-by: default avatarLiao, Wei <wei.liao@intel.com>
parent 4a06e124
......@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
set_random_seed(seed)
device = torch.device(f"cuda:{local_rank}")
device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
......
......@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif(
reason="Backend not supported",
)
DEVICE_TYPE = current_platform.device_type
DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
[
f"{DEVICE_TYPE}:{i}"
for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
if current_platform.is_cuda_alike()
else ["cpu"]
)
......@@ -196,7 +200,7 @@ def create_random_inputs(
input_size: tuple[int, ...],
input_range: tuple[float, float],
input_type: torch.dtype = torch.int,
device: torch.device = "cuda",
device: torch.device = DEVICE_TYPE,
) -> tuple[list[torch.Tensor], list[int], list[int]]:
"""Creates random inputs.
......
......@@ -35,9 +35,9 @@ EMBEDDING_MODULES = {
"lm_head": "output_embeddings",
}
DEVICE_TYPE = current_platform.device_type
DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
[f"{DEVICE_TYPE}:{i}" for i in range(min(torch.accelerator.device_count(), 2))]
if current_platform.is_cuda_alike()
else ["cpu"]
)
......
......@@ -6,6 +6,9 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
def round_up(x, base):
......@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids[i, j] = pool[j]
token_lora_mapping[i] = random.randint(0, max_loras - 1)
return topk_ids.to("cuda"), token_lora_mapping.to("cuda")
return topk_ids.to(DEVICE_TYPE), token_lora_mapping.to(DEVICE_TYPE)
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
......@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size(
(max_loras * max_num_tokens_padded,),
topk_ids.numel(),
dtype=torch.int32,
device="cuda",
device=DEVICE_TYPE,
)
expert_ids = torch.full(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
(max_loras * max_num_m_blocks,),
num_experts,
dtype=torch.int32,
device=DEVICE_TYPE,
)
num_tokens_post_pad = torch.zeros(
(max_loras,), dtype=torch.int32, device=DEVICE_TYPE
)
adapter_enabled = torch.ones(
(max_loras + 1,), dtype=torch.int32, device=DEVICE_TYPE
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device=DEVICE_TYPE)
# call kernel
ops.moe_lora_align_block_size(
......
......@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
DEVICE_TYPE = current_platform.device_type
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
......@@ -146,7 +149,9 @@ def check_lora_shrink_kernel(
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(
max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
max_loras=num_loras,
max_num_tokens=token_nums,
device=DEVICE_TYPE,
)
lora_meta.prepare_tensors(data.token_lora_mapping)
......@@ -219,7 +224,9 @@ def check_lora_expand_kernel(
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(
max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
max_loras=num_loras,
max_num_tokens=token_nums,
device=DEVICE_TYPE,
)
lora_meta.prepare_tensors(data.token_lora_mapping)
......@@ -367,7 +374,7 @@ test_params = {
}
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [0]
......
......@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
_SHRINK_LORA_SCALE_PTR_DICT,
)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DEVICES = [f"cuda:{0}"]
DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [0]
_dict_lock = Lock()
......
......@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.model_manager import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.v1.worker.gpu_worker import Worker
MODEL_PATH = "Qwen/Qwen3-0.6B"
NUM_LORAS = 16
DEVICE_TYPE = current_platform.device_type
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(qwen3_lora_files):
......@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files):
max_num_seqs=32,
max_num_partial_prefills=32,
),
device_config=DeviceConfig("cuda"),
device_config=DeviceConfig(DEVICE_TYPE),
cache_config=CacheConfig(
block_size=16,
cache_dtype="auto",
......
......@@ -9,10 +9,13 @@ import torch
from safetensors.torch import save_file
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
class DummyLoRAManager:
def __init__(self, device: torch.device = "cuda:0"):
def __init__(self, device: torch.device = f"{DEVICE_TYPE}:0"):
super().__init__()
self._loras: dict[str, LoRALayerWeights] = {}
self._device = device
......@@ -57,8 +60,8 @@ class DummyLoRAManager:
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([rank, input_dim], device="cuda"),
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
lora_a=torch.rand([rank, input_dim], device=DEVICE_TYPE),
lora_b=torch.rand([output_dim, input_dim], device=DEVICE_TYPE),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)
......
......@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [
"FLEX_ATTENTION_SLOW",
]
DEVICE_TYPE = current_platform.device_type
# Remove flashinfer from the list if it's not available
try:
import flashinfer # noqa: F401
......@@ -366,7 +368,7 @@ def _test_backend_correctness(
num_gpu_blocks=8192,
hf_config_override=hf_config_override,
)
device = torch.device("cuda:0")
device = torch.device(f"{DEVICE_TYPE}:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
......
......@@ -7,6 +7,7 @@ import pytest
import torch
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches
......@@ -22,6 +23,8 @@ class LocalAttentionTestData:
expected_local_block_table: list[list[int]]
DEVICE_TYPE = current_platform.device_type
test_data_list = [
# Same as example in docstring of make_local_attention_virtual_batches
# except block table has 9 columns instead of 10
......@@ -151,7 +154,7 @@ test_data_list = [
@pytest.mark.parametrize("test_data", test_data_list)
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
device = torch.device("cuda:0")
device = torch.device(f"{DEVICE_TYPE}:0")
batch_spec = test_data.batch_spec
attn_chunk_size = test_data.attn_chunk_size
block_size = test_data.block_size
......
......@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [
AttentionBackendEnum.TRITON_MLA,
]
DEVICE_TYPE = current_platform.device_type
# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
......@@ -763,7 +765,7 @@ def test_backend_correctness(
method="ngram", num_speculative_tokens=query_len - 1
)
device = torch.device("cuda:0")
device = torch.device(f"{DEVICE_TYPE}:0")
# 1. Setup
batch_size = batch_spec.batch_size
......
......@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2
)
DEVICE_TYPE = current_platform.device_type
def _float_to_e8m0_truncate(f: float) -> float:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
......@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
# Model hyper-parameters (kept intentionally small for the unit test)
......@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
def test_triton_convert_req_index_to_global_index_decode_only(
block_size, num_topk_tokens
):
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
num_tokens = 8
num_requests = 4
max_blocks_per_req = 10
......@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
num_requests = 4
max_blocks_per_req = 8
num_topk_tokens = 128
......@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
num_tokens = 8
num_requests = 2
max_blocks_per_req = 10
......
......@@ -55,6 +55,7 @@ class MockAttentionLayer:
MODEL = "Qwen/Qwen2.5-0.5B"
BLOCK_SIZE = 16
NUM_GPU_BLOCKS = 8192
DEVICE_TYPE = current_platform.device_type
BATCH_SPECS = {
"decode_only": BatchSpec(
......@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec):
"""Run TRTLLM attention through the full FlashInfer pipeline
and compare against an SDPA reference."""
set_random_seed(42)
device = torch.device("cuda:0")
device = torch.device(f"{DEVICE_TYPE}:0")
vllm_config = create_vllm_config(
model_name=MODEL,
......
......@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
DEVICE_TYPE = current_platform.device_type
# Helper MLP for testing
class SimpleMLP(nn.Module):
......@@ -269,9 +271,9 @@ class TestCudagraphDispatcher:
class TestCUDAGraphWrapper:
def setup_method(self):
self.vllm_config = _create_vllm_config(CompilationConfig())
self.model = SimpleMLP().to("cuda")
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda")
self.model = SimpleMLP().to(DEVICE_TYPE)
self.persistent_input_buffer = torch.zeros(1, 10, device=DEVICE_TYPE)
self.input_tensor = torch.randn(1, 10, device=DEVICE_TYPE)
def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper(
......@@ -428,10 +430,10 @@ class TestCudagraphIntegration:
@create_new_process_for_each_test("spawn")
def test_capture_replay_bypass_logic(self):
model = SimpleMLP().to("cuda")
model = SimpleMLP().to(DEVICE_TYPE)
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
max_bs = 16
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
persistent_input_buffer = torch.zeros(max_bs, 10, device=DEVICE_TYPE)
input_1 = persistent_input_buffer[:1]
input_2 = persistent_input_buffer[:2]
input_3 = persistent_input_buffer[:3]
......@@ -486,17 +488,17 @@ class TestCudagraphIntegration:
@create_new_process_for_each_test("spawn")
def test_nested_wrappers(self):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model = SimpleMLP().to("cuda")
model = SimpleMLP().to(DEVICE_TYPE)
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
input_1 = torch.randn(1, 10, device="cuda")
input_1 = torch.randn(1, 10, device=DEVICE_TYPE)
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model = SimpleMLP().to("cuda")
inner_model = SimpleMLP().to(DEVICE_TYPE)
piecewise_wrapper = CUDAGraphWrapper(
inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
)
inner_model.forward = MagicMock(wraps=inner_model.forward)
outer_model = SimpleMLP().to("cuda")
outer_model = SimpleMLP().to(DEVICE_TYPE)
# When outer model is called, it calls the piecewise_wrapper
outer_model.forward = MagicMock(
wraps=outer_model.forward, side_effect=piecewise_wrapper
......
......@@ -13,6 +13,9 @@ from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@skip_unsupported
......@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
equivalent results to the standard CUDA implementation across various
configurations.
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
# Create test input and weight
torch.manual_seed(42)
......@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
eps = 1e-6
......@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.float16
eps = 1e-6
hidden_size = 2048
......@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.float32 # Use float32 for higher precision in formula check
eps = 1e-6
hidden_size = 1024
......@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
eps = 1e-6
batch_size = 16
......@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
Runs the same input through the kernel multiple times and verifies
identical outputs.
"""
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
eps = 1e-6
hidden_size = 4096
......@@ -283,7 +286,7 @@ if __name__ == "__main__":
# Run a quick smoke test
print("Running quick smoke test of RMS norm implementations...")
device = torch.device("cuda")
device = torch.device(DEVICE_TYPE)
batch_size = 8
hidden_size = 4096
dtype = torch.bfloat16
......
......@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
......@@ -48,6 +49,7 @@ num_accepted_tokens = 1
prompt_token_ids: list[int] = []
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE = 560
DEVICE_TYPE = current_platform.device_type
NUM_HIDDEN_LAYERS = 1
cur_step_action_idx = 0
cur_step_action: StepAction | None = None
......@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
return SamplerOutput(
sampled_token_ids=torch.tensor(
[[prompt_token_ids[first_token_id_index]]],
device="cuda",
device=DEVICE_TYPE,
dtype=torch.int32,
),
logprobs_tensors=None,
......@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
sampled_token_ids = accepted_tokens
return SamplerOutput(
sampled_token_ids=torch.tensor(
[sampled_token_ids], device="cuda", dtype=torch.int32
[sampled_token_ids],
device=DEVICE_TYPE,
dtype=torch.int32,
),
logprobs_tensors=None,
)
......@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
- 1
+ num_accepted_tokens
],
device="cuda",
device=DEVICE_TYPE,
dtype=torch.int32,
)
valid_sampled_tokens_count = torch.tensor(
[num_accepted_tokens], device="cuda", dtype=torch.int32
[num_accepted_tokens],
device=DEVICE_TYPE,
dtype=torch.int32,
)
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
return torch.tensor(
proposed_draft_token_ids,
device=DEVICE_TYPE,
dtype=torch.int32,
)
return fake_propose_draft_token_ids_fn
......
......@@ -6,6 +6,7 @@ import time
import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import (
......@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024]
BLOCK_SIZE_FACTORS = [1, 3]
NUM_TENSORS = [4]
SEEDS = [0]
CUDA_DEVICES = ["cuda:0"]
DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:0"]
NUM_MAPPINGS = [3]
......@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_tensors", NUM_TENSORS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_transfer(
default_vllm_config,
......
......@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
f"{current_platform.device_type}:{i}"
DEVICE_TYPE = current_platform.device_type
DEVICES = [
f"{DEVICE_TYPE}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
......@@ -801,7 +802,7 @@ def _assert_valid(
@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(
......
......@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type
DEVICE_TYPE = current_platform.device_type
@pytest.fixture
......@@ -57,7 +57,7 @@ def create_logits_tensor(
will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE_TYPE)
start_loc = 0
for tokens in token_ids:
for j, token_id in enumerate(tokens):
......@@ -99,9 +99,9 @@ def create_sampling_metadata(
assert output_token_ids
assert len(output_token_ids) > 0
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE_TYPE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE_TYPE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE_TYPE)
else:
no_penalties = True
frequency_penalties = torch.tensor([])
......@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
)
draft_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device=DEVICE_TYPE,
)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
......@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i)
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i)
for i in range(batch_size)
if seeded_mask[i]
}
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs
)
......@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
much more than the distance improvement between the observed
distribution and the random distribution.
"""
torch.set_default_device(DEVICE)
torch.set_default_device(DEVICE_TYPE)
vocab_size = 10
k = 2
num_reference_probs = 100
......@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_logits, k, vocab_size, num_samples
)
rej_sample_probs = rej_sample_probs.to(DEVICE)
rej_sample_probs = rej_sample_probs.to(DEVICE_TYPE)
# Average distance from reference probs.
reference_vs_rejsample_dist = (
......@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE_TYPE).repeat(
num_samples, 1
)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature
)
......@@ -600,7 +613,7 @@ def _test_masked_logits(
# Create random draft probabilities.
draft_probs = torch.rand(
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE_TYPE
)
draft_probs = F.softmax(draft_probs, dim=-1)
......@@ -610,7 +623,11 @@ def _test_masked_logits(
draft_token_ids = draft_token_ids.tolist()
# Bonus tokens not used but required
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
bonus_token_ids = torch.zeros(
(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
)
# Create spec decode metadata
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
......@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices.
top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)
torch.randperm(vocab_size, device=DEVICE_TYPE)[:top_k]
for _ in range(num_tokens)
]
top_k_indices = torch.stack(top_k_indices)
# Create logits with the uniform distribution.
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE_TYPE)
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
......@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
target_logits[i, top_k_indices[i]] += 0.1
# Create sampling metadata
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
top_k=torch.tensor([top_k] * batch_size, device=DEVICE_TYPE, dtype=torch.int64),
)
_test_masked_logits(
......@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
num_tokens = batch_size * num_draft_tokens
# Create logits with the uniform distribution.
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE_TYPE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
rescaled_logits = target_logits / temperature
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
......@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
top_p=torch.tensor(
[top_p] * batch_size,
device=DEVICE_TYPE,
dtype=torch.float32,
),
)
_test_masked_logits(
......@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
prompt_token_ids=torch.tensor(
[[5, 6, 7], [6, 7, 8], [7, 8, 9]],
device=DEVICE_TYPE,
),
frequency_penalties=[1.5, 1.5, 0.7],
presence_penalties=[0.0] * num_requests,
repetition_penalties=[1.0] * num_requests,
......@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
num_tokens = batch_size * max_spec_len
# Create random draft probabilities.
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1)
# Create random target probabilities.
target_logits = torch.rand(
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE
)
target_probs = F.softmax(target_logits, dim=-1)
# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
generators = {
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
}
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=generators
......@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
device=DEVICE_TYPE,
)
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
......@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
device=DEVICE_TYPE,
)
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
......@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
f"{current_platform.device_type}:{i}"
DEVICE_TYPE = current_platform.device_type
DEVICES = [
f"{DEVICE_TYPE}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
......@@ -199,7 +200,7 @@ def _create_weighted_output_token_list(
return output_token_ids, sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(
......@@ -249,7 +250,7 @@ def test_sampler_presence_penalty(
assert penalized_token_id not in output_token_ids[batch_idx]
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(
......@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty(
assert penalized_token_id not in distinct_sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(
......@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty(
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(
......@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids(
assert logits_for_req[token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)])
def test_sampler_bad_words(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment