Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32e0c0bf
Unverified
Commit
32e0c0bf
authored
Apr 02, 2026
by
wliao2
Committed by
GitHub
Apr 03, 2026
Browse files
refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by:
Liao, Wei
<
wei.liao@intel.com
>
parent
4a06e124
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
178 additions
and
90 deletions
+178
-90
tests/lora/test_fused_moe_lora_kernel.py
tests/lora/test_fused_moe_lora_kernel.py
+1
-1
tests/lora/test_layers.py
tests/lora/test_layers.py
+6
-2
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+2
-2
tests/lora/test_moe_lora_align_sum.py
tests/lora/test_moe_lora_align_sum.py
+16
-6
tests/lora/test_punica_ops.py
tests/lora/test_punica_ops.py
+10
-3
tests/lora/test_punica_ops_fp8.py
tests/lora/test_punica_ops_fp8.py
+3
-1
tests/lora/test_worker.py
tests/lora/test_worker.py
+4
-1
tests/lora/utils.py
tests/lora/utils.py
+6
-3
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+3
-1
tests/v1/attention/test_chunked_local_attention.py
tests/v1/attention/test_chunked_local_attention.py
+4
-1
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+3
-1
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+6
-4
tests/v1/attention/test_trtllm_attention_integration.py
tests/v1/attention/test_trtllm_attention_integration.py
+2
-1
tests/v1/cudagraph/test_cudagraph_dispatch.py
tests/v1/cudagraph/test_cudagraph_dispatch.py
+11
-9
tests/v1/determinism/test_rms_norm_batch_invariant.py
tests/v1/determinism/test_rms_norm_batch_invariant.py
+10
-7
tests/v1/e2e/general/test_mamba_prefix_cache.py
tests/v1/e2e/general/test_mamba_prefix_cache.py
+15
-5
tests/v1/kv_offload/test_cpu_gpu.py
tests/v1/kv_offload/test_cpu_gpu.py
+4
-2
tests/v1/logits_processors/test_correctness.py
tests/v1/logits_processors/test_correctness.py
+4
-3
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+60
-30
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+8
-7
No files found.
tests/lora/test_fused_moe_lora_kernel.py
View file @
32e0c0bf
...
...
@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
set_random_seed
(
seed
)
device
=
torch
.
device
(
f
"
cuda
:
{
local_rank
}
"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:
{
local_rank
}
"
)
torch
.
accelerator
.
set_device_index
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
...
...
tests/lora/test_layers.py
View file @
32e0c0bf
...
...
@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif(
reason
=
"Backend not supported"
,
)
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
(
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)]
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)
]
if
current_platform
.
is_cuda_alike
()
else
[
"cpu"
]
)
...
...
@@ -196,7 +200,7 @@ def create_random_inputs(
input_size
:
tuple
[
int
,
...],
input_range
:
tuple
[
float
,
float
],
input_type
:
torch
.
dtype
=
torch
.
int
,
device
:
torch
.
device
=
"cuda"
,
device
:
torch
.
device
=
DEVICE_TYPE
,
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
int
],
list
[
int
]]:
"""Creates random inputs.
...
...
tests/lora/test_lora_manager.py
View file @
32e0c0bf
...
...
@@ -35,9 +35,9 @@ EMBEDDING_MODULES = {
"lm_head"
:
"output_embeddings"
,
}
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
(
[
f
"
cuda
:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)]
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
min
(
torch
.
accelerator
.
device_count
()
,
2
)
)
]
if
current_platform
.
is_cuda_alike
()
else
[
"cpu"
]
)
...
...
tests/lora/test_moe_lora_align_sum.py
View file @
32e0c0bf
...
...
@@ -6,6 +6,9 @@ import pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
def
round_up
(
x
,
base
):
...
...
@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids
[
i
,
j
]
=
pool
[
j
]
token_lora_mapping
[
i
]
=
random
.
randint
(
0
,
max_loras
-
1
)
return
topk_ids
.
to
(
"cuda"
),
token_lora_mapping
.
to
(
"cuda"
)
return
topk_ids
.
to
(
DEVICE_TYPE
),
token_lora_mapping
.
to
(
DEVICE_TYPE
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
...
...
@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size(
(
max_loras
*
max_num_tokens_padded
,),
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
)
expert_ids
=
torch
.
full
(
(
max_loras
*
max_num_m_blocks
,),
num_experts
,
dtype
=
torch
.
int32
,
device
=
"cuda"
(
max_loras
*
max_num_m_blocks
,),
num_experts
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
,
)
num_tokens_post_pad
=
torch
.
zeros
(
(
max_loras
,),
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
adapter_enabled
=
torch
.
ones
(
(
max_loras
+
1
,),
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
num_tokens_post_pad
=
torch
.
zeros
((
max_loras
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
adapter_enabled
=
torch
.
ones
((
max_loras
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
# call kernel
ops
.
moe_lora_align_block_size
(
...
...
tests/lora/test_punica_ops.py
View file @
32e0c0bf
...
...
@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops
import
vllm.lora.ops.triton_ops
as
triton_ops
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
PunicaTensors
,
assert_close
,
generate_data_for_nslices
DEVICE_TYPE
=
current_platform
.
device_type
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_device
(
reset_default_device
):
...
...
@@ -146,7 +149,9 @@ def check_lora_shrink_kernel(
# Setup metadata information for the LoRA kernel.
lora_meta
=
LoRAKernelMeta
.
make
(
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
"cuda"
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
DEVICE_TYPE
,
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
...
...
@@ -219,7 +224,9 @@ def check_lora_expand_kernel(
# Setup metadata information for the LoRA kernel.
lora_meta
=
LoRAKernelMeta
.
make
(
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
"cuda"
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
DEVICE_TYPE
,
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
...
...
@@ -367,7 +374,7 @@ test_params = {
}
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DEVICES
=
[
f
"
cuda
:
{
0
}
"
]
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
0
}
"
]
SEED
=
[
0
]
...
...
tests/lora/test_punica_ops_fp8.py
View file @
32e0c0bf
...
...
@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
_SHRINK_LORA_SCALE_PTR_DICT
,
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
DEVICES
=
[
f
"cuda:
{
0
}
"
]
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
0
}
"
]
SEED
=
[
0
]
_dict_lock
=
Lock
()
...
...
tests/lora/test_worker.py
View file @
32e0c0bf
...
...
@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.lora.model_manager
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.gpu_worker
import
Worker
MODEL_PATH
=
"Qwen/Qwen3-0.6B"
NUM_LORAS
=
16
DEVICE_TYPE
=
current_platform
.
device_type
@
patch
.
dict
(
os
.
environ
,
{
"RANK"
:
"0"
})
def
test_worker_apply_lora
(
qwen3_lora_files
):
...
...
@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files):
max_num_seqs
=
32
,
max_num_partial_prefills
=
32
,
),
device_config
=
DeviceConfig
(
"cuda"
),
device_config
=
DeviceConfig
(
DEVICE_TYPE
),
cache_config
=
CacheConfig
(
block_size
=
16
,
cache_dtype
=
"auto"
,
...
...
tests/lora/utils.py
View file @
32e0c0bf
...
...
@@ -9,10 +9,13 @@ import torch
from
safetensors.torch
import
save_file
from
vllm.lora.lora_weights
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
class
DummyLoRAManager
:
def
__init__
(
self
,
device
:
torch
.
device
=
"cuda
:0"
):
def
__init__
(
self
,
device
:
torch
.
device
=
f
"
{
DEVICE_TYPE
}
:0"
):
super
().
__init__
()
self
.
_loras
:
dict
[
str
,
LoRALayerWeights
]
=
{}
self
.
_device
=
device
...
...
@@ -57,8 +60,8 @@ class DummyLoRAManager:
module_name
,
rank
=
rank
,
lora_alpha
=
1
,
lora_a
=
torch
.
rand
([
rank
,
input_dim
],
device
=
"cuda"
),
lora_b
=
torch
.
rand
([
output_dim
,
input_dim
],
device
=
"cuda"
),
lora_a
=
torch
.
rand
([
rank
,
input_dim
],
device
=
DEVICE_TYPE
),
lora_b
=
torch
.
rand
([
output_dim
,
input_dim
],
device
=
DEVICE_TYPE
),
embeddings_tensor
=
embeddings_tensor
,
)
self
.
set_module_lora
(
module_name
,
lora
)
...
...
tests/v1/attention/test_attention_backends.py
View file @
32e0c0bf
...
...
@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [
"FLEX_ATTENTION_SLOW"
,
]
DEVICE_TYPE
=
current_platform
.
device_type
# Remove flashinfer from the list if it's not available
try
:
import
flashinfer
# noqa: F401
...
...
@@ -366,7 +368,7 @@ def _test_backend_correctness(
num_gpu_blocks
=
8192
,
hf_config_override
=
hf_config_override
,
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
...
...
tests/v1/attention/test_chunked_local_attention.py
View file @
32e0c0bf
...
...
@@ -7,6 +7,7 @@ import pytest
import
torch
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
make_local_attention_virtual_batches
...
...
@@ -22,6 +23,8 @@ class LocalAttentionTestData:
expected_local_block_table
:
list
[
list
[
int
]]
DEVICE_TYPE
=
current_platform
.
device_type
test_data_list
=
[
# Same as example in docstring of make_local_attention_virtual_batches
# except block table has 9 columns instead of 10
...
...
@@ -151,7 +154,7 @@ test_data_list = [
@
pytest
.
mark
.
parametrize
(
"test_data"
,
test_data_list
)
def
test_local_attention_virtual_batches
(
test_data
:
LocalAttentionTestData
):
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
batch_spec
=
test_data
.
batch_spec
attn_chunk_size
=
test_data
.
attn_chunk_size
block_size
=
test_data
.
block_size
...
...
tests/v1/attention/test_mla_backends.py
View file @
32e0c0bf
...
...
@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [
AttentionBackendEnum
.
TRITON_MLA
,
]
DEVICE_TYPE
=
current_platform
.
device_type
# Remove sm100 backends from the list if not using sm100
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
BACKENDS_TO_TEST
.
remove
(
AttentionBackendEnum
.
CUTLASS_MLA
)
...
...
@@ -763,7 +765,7 @@ def test_backend_correctness(
method
=
"ngram"
,
num_speculative_tokens
=
query_len
-
1
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
# 1. Setup
batch_size
=
batch_spec
.
batch_size
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
32e0c0bf
...
...
@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens
=
[
256
]
*
2
,
query_lens
=
[
256
]
*
2
)
DEVICE_TYPE
=
current_platform
.
device_type
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
...
...
@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
batch_spec
=
SPARSE_BACKEND_BATCH_SPECS
[
batch_name
]
use_fp8_ds_mla_quantization
=
kv_cache_dtype
==
"fp8_ds_mla"
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
# Model hyper-parameters (kept intentionally small for the unit test)
...
...
@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
def
test_triton_convert_req_index_to_global_index_decode_only
(
block_size
,
num_topk_tokens
):
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_tokens
=
8
num_requests
=
4
max_blocks_per_req
=
10
...
...
@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
reason
=
"FlashMLASparseBackend requires CUDA 9.0 or higher"
,
)
def
test_triton_convert_req_index_to_global_index_with_prefill_workspace
(
block_size
):
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_requests
=
4
max_blocks_per_req
=
8
num_topk_tokens
=
128
...
...
@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
def
test_triton_convert_returns_valid_counts
():
"""Test that return_valid_counts correctly counts non-negative indices."""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_tokens
=
8
num_requests
=
2
max_blocks_per_req
=
10
...
...
tests/v1/attention/test_trtllm_attention_integration.py
View file @
32e0c0bf
...
...
@@ -55,6 +55,7 @@ class MockAttentionLayer:
MODEL
=
"Qwen/Qwen2.5-0.5B"
BLOCK_SIZE
=
16
NUM_GPU_BLOCKS
=
8192
DEVICE_TYPE
=
current_platform
.
device_type
BATCH_SPECS
=
{
"decode_only"
:
BatchSpec
(
...
...
@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec):
"""Run TRTLLM attention through the full FlashInfer pipeline
and compare against an SDPA reference."""
set_random_seed
(
42
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
vllm_config
=
create_vllm_config
(
model_name
=
MODEL
,
...
...
tests/v1/cudagraph/test_cudagraph_dispatch.py
View file @
32e0c0bf
...
...
@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
DEVICE_TYPE
=
current_platform
.
device_type
# Helper MLP for testing
class
SimpleMLP
(
nn
.
Module
):
...
...
@@ -269,9 +271,9 @@ class TestCudagraphDispatcher:
class
TestCUDAGraphWrapper
:
def
setup_method
(
self
):
self
.
vllm_config
=
_create_vllm_config
(
CompilationConfig
())
self
.
model
=
SimpleMLP
().
to
(
"cuda"
)
self
.
persistent_input_buffer
=
torch
.
zeros
(
1
,
10
,
device
=
"cuda"
)
self
.
input_tensor
=
torch
.
randn
(
1
,
10
,
device
=
"cuda"
)
self
.
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
self
.
persistent_input_buffer
=
torch
.
zeros
(
1
,
10
,
device
=
DEVICE_TYPE
)
self
.
input_tensor
=
torch
.
randn
(
1
,
10
,
device
=
DEVICE_TYPE
)
def
test_capture_and_replay
(
self
):
wrapper
=
CUDAGraphWrapper
(
...
...
@@ -428,10 +430,10 @@ class TestCudagraphIntegration:
@
create_new_process_for_each_test
(
"spawn"
)
def
test_capture_replay_bypass_logic
(
self
):
model
=
SimpleMLP
().
to
(
"cuda"
)
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
max_bs
=
16
persistent_input_buffer
=
torch
.
zeros
(
max_bs
,
10
,
device
=
"cuda"
)
persistent_input_buffer
=
torch
.
zeros
(
max_bs
,
10
,
device
=
DEVICE_TYPE
)
input_1
=
persistent_input_buffer
[:
1
]
input_2
=
persistent_input_buffer
[:
2
]
input_3
=
persistent_input_buffer
[:
3
]
...
...
@@ -486,17 +488,17 @@ class TestCudagraphIntegration:
@
create_new_process_for_each_test
(
"spawn"
)
def
test_nested_wrappers
(
self
):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model
=
SimpleMLP
().
to
(
"cuda"
)
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
input_1
=
torch
.
randn
(
1
,
10
,
device
=
"cuda"
)
input_1
=
torch
.
randn
(
1
,
10
,
device
=
DEVICE_TYPE
)
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model
=
SimpleMLP
().
to
(
"cuda"
)
inner_model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
piecewise_wrapper
=
CUDAGraphWrapper
(
inner_model
,
self
.
vllm_config
,
CUDAGraphMode
.
PIECEWISE
)
inner_model
.
forward
=
MagicMock
(
wraps
=
inner_model
.
forward
)
outer_model
=
SimpleMLP
().
to
(
"cuda"
)
outer_model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
# When outer model is called, it calls the piecewise_wrapper
outer_model
.
forward
=
MagicMock
(
wraps
=
outer_model
.
forward
,
side_effect
=
piecewise_wrapper
...
...
tests/v1/determinism/test_rms_norm_batch_invariant.py
View file @
32e0c0bf
...
...
@@ -13,6 +13,9 @@ from utils import skip_unsupported
from
vllm.model_executor.layers.batch_invariant
import
rms_norm
as
triton_rms_norm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
@
skip_unsupported
...
...
@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
equivalent results to the standard CUDA implementation across various
configurations.
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
# Create test input and weight
torch
.
manual_seed
(
42
)
...
...
@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
eps
=
1e-6
...
...
@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
float16
eps
=
1e-6
hidden_size
=
2048
...
...
@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
float32
# Use float32 for higher precision in formula check
eps
=
1e-6
hidden_size
=
1024
...
...
@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
eps
=
1e-6
batch_size
=
16
...
...
@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
Runs the same input through the kernel multiple times and verifies
identical outputs.
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
eps
=
1e-6
hidden_size
=
4096
...
...
@@ -283,7 +286,7 @@ if __name__ == "__main__":
# Run a quick smoke test
print
(
"Running quick smoke test of RMS norm implementations..."
)
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
batch_size
=
8
hidden_size
=
4096
dtype
=
torch
.
bfloat16
...
...
tests/v1/e2e/general/test_mamba_prefix_cache.py
View file @
32e0c0bf
...
...
@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.model_executor.layers.mamba.mamba_utils
import
MambaStateCopyFunc
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
...
...
@@ -48,6 +49,7 @@ num_accepted_tokens = 1
prompt_token_ids
:
list
[
int
]
=
[]
MODEL
=
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE
=
560
DEVICE_TYPE
=
current_platform
.
device_type
NUM_HIDDEN_LAYERS
=
1
cur_step_action_idx
=
0
cur_step_action
:
StepAction
|
None
=
None
...
...
@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
return
SamplerOutput
(
sampled_token_ids
=
torch
.
tensor
(
[[
prompt_token_ids
[
first_token_id_index
]]],
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
),
logprobs_tensors
=
None
,
...
...
@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
sampled_token_ids
=
accepted_tokens
return
SamplerOutput
(
sampled_token_ids
=
torch
.
tensor
(
[
sampled_token_ids
],
device
=
"cuda"
,
dtype
=
torch
.
int32
[
sampled_token_ids
],
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
),
logprobs_tensors
=
None
,
)
...
...
@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
-
1
+
num_accepted_tokens
],
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
)
valid_sampled_tokens_count
=
torch
.
tensor
(
[
num_accepted_tokens
],
device
=
"cuda"
,
dtype
=
torch
.
int32
[
num_accepted_tokens
],
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
return
torch
.
tensor
(
proposed_draft_token_ids
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
torch
.
tensor
(
proposed_draft_token_ids
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
)
return
fake_propose_draft_token_ids_fn
...
...
tests/v1/kv_offload/test_cpu_gpu.py
View file @
32e0c0bf
...
...
@@ -6,6 +6,7 @@ import time
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.spec
import
(
...
...
@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024]
BLOCK_SIZE_FACTORS
=
[
1
,
3
]
NUM_TENSORS
=
[
4
]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
"cuda:0"
]
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:0"
]
NUM_MAPPINGS
=
[
3
]
...
...
@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3]
@
pytest
.
mark
.
parametrize
(
"num_cpu_blocks"
,
NUM_CPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_tensors"
,
NUM_TENSORS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
torch
.
inference_mode
()
def
test_transfer
(
default_vllm_config
,
...
...
tests/v1/logits_processors/test_correctness.py
View file @
32e0c0bf
...
...
@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
MAX_NUM_PROMPT_TOKENS
=
64
...
...
@@ -801,7 +802,7 @@ def _assert_valid(
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"reqs_per_logitproc"
,
[
REQS_PER_LOGITPROC
])
@
pytest
.
mark
.
parametrize
(
"logitsprocs_under_test"
,
_get_test_cases
())
def
test_logitsprocs
(
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
32e0c0bf
...
...
@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
from
vllm.v1.sample.sampler
import
Sampler
,
SamplerOutput
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
DEVICE
=
current_platform
.
device_type
DEVICE
_TYPE
=
current_platform
.
device_type
@
pytest
.
fixture
...
...
@@ -57,7 +57,7 @@ def create_logits_tensor(
will produce desired token ids on argmax"""
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
logits
=
torch
.
full
((
num_total_tokens
,
vocab_size
),
-
100.0
,
device
=
DEVICE
)
logits
=
torch
.
full
((
num_total_tokens
,
vocab_size
),
-
100.0
,
device
=
DEVICE
_TYPE
)
start_loc
=
0
for
tokens
in
token_ids
:
for
j
,
token_id
in
enumerate
(
tokens
):
...
...
@@ -99,9 +99,9 @@ def create_sampling_metadata(
assert
output_token_ids
assert
len
(
output_token_ids
)
>
0
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
device
=
DEVICE
)
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
device
=
DEVICE
)
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
device
=
DEVICE
)
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
device
=
DEVICE
_TYPE
)
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
device
=
DEVICE
_TYPE
)
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
device
=
DEVICE
_TYPE
)
else
:
no_penalties
=
True
frequency_penalties
=
torch
.
tensor
([])
...
...
@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
n_rep
:
int
,
):
num_tokens
=
batch_size
*
k
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
,
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
target_logits
=
torch
.
rand_like
(
draft_probs
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
draft_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
DEVICE
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
seeded_mask
=
torch
.
rand
(
batch_size
,
dtype
=
torch
.
float32
)
<=
frac_seeded
...
...
@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
results
=
[]
for
_
in
range
(
n_rep
):
seeded_seqs
=
{
i
:
torch
.
Generator
(
device
=
DEVICE
).
manual_seed
(
i
)
i
:
torch
.
Generator
(
device
=
DEVICE
_TYPE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
if
seeded_mask
[
i
]
}
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
seeded_seqs
)
...
...
@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
much more than the distance improvement between the observed
distribution and the random distribution.
"""
torch
.
set_default_device
(
DEVICE
)
torch
.
set_default_device
(
DEVICE
_TYPE
)
vocab_size
=
10
k
=
2
num_reference_probs
=
100
...
...
@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
rej_sample_probs
=
estimate_rejection_sampling_pdf
(
draft_probs
,
target_logits
,
k
,
vocab_size
,
num_samples
)
rej_sample_probs
=
rej_sample_probs
.
to
(
DEVICE
)
rej_sample_probs
=
rej_sample_probs
.
to
(
DEVICE
_TYPE
)
# Average distance from reference probs.
reference_vs_rejsample_dist
=
(
...
...
@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
draft_probs
=
draft_probs
.
view
(
num_tokens
,
vocab_size
)
# Bonus tokens not used but required.
bonus_token_ids
=
torch
.
zeros
((
1
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
).
repeat
(
bonus_token_ids
=
torch
.
zeros
((
1
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
_TYPE
).
repeat
(
num_samples
,
1
)
temperature
=
torch
.
ones
(
num_samples
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
num_samples
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
)
...
...
@@ -600,7 +613,7 @@ def _test_masked_logits(
# Create random draft probabilities.
draft_probs
=
torch
.
rand
(
(
num_tokens
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
DEVICE
(
num_tokens
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
...
...
@@ -610,7 +623,11 @@ def _test_masked_logits(
draft_token_ids
=
draft_token_ids
.
tolist
()
# Bonus tokens not used but required
bonus_token_ids
=
torch
.
zeros
((
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
)
bonus_token_ids
=
torch
.
zeros
(
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
# Create spec decode metadata
spec_decode_metadata
=
create_spec_decode_metadata
(
draft_token_ids
,
target_logits
)
...
...
@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices.
top_k_indices
=
[
torch
.
randperm
(
vocab_size
,
device
=
DEVICE
)[:
top_k
]
for
_
in
range
(
num_tokens
)
torch
.
randperm
(
vocab_size
,
device
=
DEVICE_TYPE
)[:
top_k
]
for
_
in
range
(
num_tokens
)
]
top_k_indices
=
torch
.
stack
(
top_k_indices
)
# Create logits with the uniform distribution.
target_logits
=
torch
.
zeros
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
target_logits
=
torch
.
zeros
((
num_tokens
,
vocab_size
),
device
=
DEVICE
_TYPE
)
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
...
...
@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
target_logits
[
i
,
top_k_indices
[
i
]]
+=
0.1
# Create sampling metadata
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
top_k
=
torch
.
tensor
([
top_k
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
int64
),
top_k
=
torch
.
tensor
([
top_k
]
*
batch_size
,
device
=
DEVICE
_TYPE
,
dtype
=
torch
.
int64
),
)
_test_masked_logits
(
...
...
@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
num_tokens
=
batch_size
*
num_draft_tokens
# Create logits with the uniform distribution.
target_logits
=
torch
.
randn
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
target_logits
=
torch
.
randn
((
num_tokens
,
vocab_size
),
device
=
DEVICE
_TYPE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
rescaled_logits
=
target_logits
/
temperature
logits_sort
,
logits_idx
=
rescaled_logits
.
sort
(
dim
=-
1
,
descending
=
False
)
...
...
@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
top_p
=
torch
.
tensor
([
top_p
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
float32
),
top_p
=
torch
.
tensor
(
[
top_p
]
*
batch_size
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
float32
,
),
)
_test_masked_logits
(
...
...
@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
all_greedy
=
True
,
output_token_ids
=
[[
2
],
[
3
],
[
4
]],
spec_token_ids
=
spec_tokens
,
prompt_token_ids
=
torch
.
tensor
([[
5
,
6
,
7
],
[
6
,
7
,
8
],
[
7
,
8
,
9
]],
device
=
DEVICE
),
prompt_token_ids
=
torch
.
tensor
(
[[
5
,
6
,
7
],
[
6
,
7
,
8
],
[
7
,
8
,
9
]],
device
=
DEVICE_TYPE
,
),
frequency_penalties
=
[
1.5
,
1.5
,
0.7
],
presence_penalties
=
[
0.0
]
*
num_requests
,
repetition_penalties
=
[
1.0
]
*
num_requests
,
...
...
@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
num_tokens
=
batch_size
*
max_spec_len
# Create random draft probabilities.
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
,
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
# Create random target probabilities.
target_logits
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
target_probs
=
F
.
softmax
(
target_logits
,
dim
=-
1
)
# Randomly sample draft token ids from draft probs
draft_token_ids
=
torch
.
multinomial
(
draft_probs
,
num_samples
=
1
).
to
(
torch
.
int32
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
generators
=
{
i
:
torch
.
Generator
(
device
=
DEVICE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
i
:
torch
.
Generator
(
device
=
DEVICE
_TYPE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
}
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
generators
...
...
@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
None
if
no_draft_probs
else
draft_probs
,
target_probs
,
sampling_metadata
,
device
=
DEVICE
,
device
=
DEVICE
_TYPE
,
)
recovered_token_ids
=
sample_recovered_tokens
(
max_spec_len
,
...
...
@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
None
if
no_draft_probs
else
draft_probs
,
target_probs
,
sampling_metadata
,
device
=
DEVICE
,
device
=
DEVICE
_TYPE
,
)
assert
torch
.
equal
(
recovered_token_ids
,
ref_recovered_token_ids
)
tests/v1/sample/test_sampler.py
View file @
32e0c0bf
...
...
@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
MAX_NUM_PROMPT_TOKENS
=
64
...
...
@@ -199,7 +200,7 @@ def _create_weighted_output_token_list(
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
def
test_sampler_presence_penalty
(
...
...
@@ -249,7 +250,7 @@ def test_sampler_presence_penalty(
assert
penalized_token_id
not
in
output_token_ids
[
batch_idx
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"frequency_penalty"
,
[
-
2.0
,
2.0
])
def
test_sampler_frequency_penalty
(
...
...
@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty(
assert
penalized_token_id
not
in
distinct_sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"repetition_penalty"
,
[
0.1
,
1.9
])
def
test_sampler_repetition_penalty
(
...
...
@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty(
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
def
test_sampler_allowed_token_ids
(
...
...
@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids(
assert
logits_for_req
[
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bad_words_lengths"
,
[(
1
,),
(
1
,
3
),
(
2
,
2
)])
def
test_sampler_bad_words
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment