Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32e0c0bf
Unverified
Commit
32e0c0bf
authored
Apr 02, 2026
by
wliao2
Committed by
GitHub
Apr 03, 2026
Browse files
refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by:
Liao, Wei
<
wei.liao@intel.com
>
parent
4a06e124
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
178 additions
and
90 deletions
+178
-90
tests/lora/test_fused_moe_lora_kernel.py
tests/lora/test_fused_moe_lora_kernel.py
+1
-1
tests/lora/test_layers.py
tests/lora/test_layers.py
+6
-2
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+2
-2
tests/lora/test_moe_lora_align_sum.py
tests/lora/test_moe_lora_align_sum.py
+16
-6
tests/lora/test_punica_ops.py
tests/lora/test_punica_ops.py
+10
-3
tests/lora/test_punica_ops_fp8.py
tests/lora/test_punica_ops_fp8.py
+3
-1
tests/lora/test_worker.py
tests/lora/test_worker.py
+4
-1
tests/lora/utils.py
tests/lora/utils.py
+6
-3
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+3
-1
tests/v1/attention/test_chunked_local_attention.py
tests/v1/attention/test_chunked_local_attention.py
+4
-1
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+3
-1
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+6
-4
tests/v1/attention/test_trtllm_attention_integration.py
tests/v1/attention/test_trtllm_attention_integration.py
+2
-1
tests/v1/cudagraph/test_cudagraph_dispatch.py
tests/v1/cudagraph/test_cudagraph_dispatch.py
+11
-9
tests/v1/determinism/test_rms_norm_batch_invariant.py
tests/v1/determinism/test_rms_norm_batch_invariant.py
+10
-7
tests/v1/e2e/general/test_mamba_prefix_cache.py
tests/v1/e2e/general/test_mamba_prefix_cache.py
+15
-5
tests/v1/kv_offload/test_cpu_gpu.py
tests/v1/kv_offload/test_cpu_gpu.py
+4
-2
tests/v1/logits_processors/test_correctness.py
tests/v1/logits_processors/test_correctness.py
+4
-3
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+60
-30
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+8
-7
No files found.
tests/lora/test_fused_moe_lora_kernel.py
View file @
32e0c0bf
...
@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
...
@@ -637,7 +637,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
set_random_seed
(
seed
)
set_random_seed
(
seed
)
device
=
torch
.
device
(
f
"
cuda
:
{
local_rank
}
"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:
{
local_rank
}
"
)
torch
.
accelerator
.
set_device_index
(
device
)
torch
.
accelerator
.
set_device_index
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
...
...
tests/lora/test_layers.py
View file @
32e0c0bf
...
@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif(
...
@@ -60,8 +60,12 @@ pytestmark = pytest.mark.skipif(
reason
=
"Backend not supported"
,
reason
=
"Backend not supported"
,
)
)
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
(
DEVICES
=
(
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)]
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)
]
if
current_platform
.
is_cuda_alike
()
if
current_platform
.
is_cuda_alike
()
else
[
"cpu"
]
else
[
"cpu"
]
)
)
...
@@ -196,7 +200,7 @@ def create_random_inputs(
...
@@ -196,7 +200,7 @@ def create_random_inputs(
input_size
:
tuple
[
int
,
...],
input_size
:
tuple
[
int
,
...],
input_range
:
tuple
[
float
,
float
],
input_range
:
tuple
[
float
,
float
],
input_type
:
torch
.
dtype
=
torch
.
int
,
input_type
:
torch
.
dtype
=
torch
.
int
,
device
:
torch
.
device
=
"cuda"
,
device
:
torch
.
device
=
DEVICE_TYPE
,
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
int
],
list
[
int
]]:
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
int
],
list
[
int
]]:
"""Creates random inputs.
"""Creates random inputs.
...
...
tests/lora/test_lora_manager.py
View file @
32e0c0bf
...
@@ -35,9 +35,9 @@ EMBEDDING_MODULES = {
...
@@ -35,9 +35,9 @@ EMBEDDING_MODULES = {
"lm_head"
:
"output_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
}
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
(
DEVICES
=
(
[
f
"
cuda
:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)]
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
min
(
torch
.
accelerator
.
device_count
()
,
2
)
)
]
if
current_platform
.
is_cuda_alike
()
if
current_platform
.
is_cuda_alike
()
else
[
"cpu"
]
else
[
"cpu"
]
)
)
...
...
tests/lora/test_moe_lora_align_sum.py
View file @
32e0c0bf
...
@@ -6,6 +6,9 @@ import pytest
...
@@ -6,6 +6,9 @@ import pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
def
round_up
(
x
,
base
):
def
round_up
(
x
,
base
):
...
@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
...
@@ -27,7 +30,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids
[
i
,
j
]
=
pool
[
j
]
topk_ids
[
i
,
j
]
=
pool
[
j
]
token_lora_mapping
[
i
]
=
random
.
randint
(
0
,
max_loras
-
1
)
token_lora_mapping
[
i
]
=
random
.
randint
(
0
,
max_loras
-
1
)
return
topk_ids
.
to
(
"cuda"
),
token_lora_mapping
.
to
(
"cuda"
)
return
topk_ids
.
to
(
DEVICE_TYPE
),
token_lora_mapping
.
to
(
DEVICE_TYPE
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
,
200
,
1024
,
4096
])
# 81920
...
@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size(
...
@@ -56,14 +59,21 @@ def test_moe_lora_align_block_size(
(
max_loras
*
max_num_tokens_padded
,),
(
max_loras
*
max_num_tokens_padded
,),
topk_ids
.
numel
(),
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
)
)
expert_ids
=
torch
.
full
(
expert_ids
=
torch
.
full
(
(
max_loras
*
max_num_m_blocks
,),
num_experts
,
dtype
=
torch
.
int32
,
device
=
"cuda"
(
max_loras
*
max_num_m_blocks
,),
num_experts
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
,
)
num_tokens_post_pad
=
torch
.
zeros
(
(
max_loras
,),
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
adapter_enabled
=
torch
.
ones
(
(
max_loras
+
1
,),
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
)
num_tokens_post_pad
=
torch
.
zeros
((
max_loras
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
,
device
=
DEVICE_TYPE
)
adapter_enabled
=
torch
.
ones
((
max_loras
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# call kernel
# call kernel
ops
.
moe_lora_align_block_size
(
ops
.
moe_lora_align_block_size
(
...
...
tests/lora/test_punica_ops.py
View file @
32e0c0bf
...
@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops
...
@@ -9,10 +9,13 @@ import vllm.lora.ops.torch_ops as torch_ops
import
vllm.lora.ops.triton_ops
as
triton_ops
import
vllm.lora.ops.triton_ops
as
triton_ops
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
PunicaTensors
,
assert_close
,
generate_data_for_nslices
from
.utils
import
PunicaTensors
,
assert_close
,
generate_data_for_nslices
DEVICE_TYPE
=
current_platform
.
device_type
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_device
(
reset_default_device
):
def
reset_device
(
reset_default_device
):
...
@@ -146,7 +149,9 @@ def check_lora_shrink_kernel(
...
@@ -146,7 +149,9 @@ def check_lora_shrink_kernel(
# Setup metadata information for the LoRA kernel.
# Setup metadata information for the LoRA kernel.
lora_meta
=
LoRAKernelMeta
.
make
(
lora_meta
=
LoRAKernelMeta
.
make
(
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
"cuda"
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
DEVICE_TYPE
,
)
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
...
@@ -219,7 +224,9 @@ def check_lora_expand_kernel(
...
@@ -219,7 +224,9 @@ def check_lora_expand_kernel(
# Setup metadata information for the LoRA kernel.
# Setup metadata information for the LoRA kernel.
lora_meta
=
LoRAKernelMeta
.
make
(
lora_meta
=
LoRAKernelMeta
.
make
(
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
"cuda"
max_loras
=
num_loras
,
max_num_tokens
=
token_nums
,
device
=
DEVICE_TYPE
,
)
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
lora_meta
.
prepare_tensors
(
data
.
token_lora_mapping
)
...
@@ -367,7 +374,7 @@ test_params = {
...
@@ -367,7 +374,7 @@ test_params = {
}
}
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DEVICES
=
[
f
"
cuda
:
{
0
}
"
]
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
0
}
"
]
SEED
=
[
0
]
SEED
=
[
0
]
...
...
tests/lora/test_punica_ops_fp8.py
View file @
32e0c0bf
...
@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
...
@@ -28,9 +28,11 @@ from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
_SHRINK_LORA_SCALE_PTR_DICT
,
_SHRINK_LORA_SCALE_PTR_DICT
,
)
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
DEVICES
=
[
f
"cuda:
{
0
}
"
]
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
0
}
"
]
SEED
=
[
0
]
SEED
=
[
0
]
_dict_lock
=
Lock
()
_dict_lock
=
Lock
()
...
...
tests/lora/test_worker.py
View file @
32e0c0bf
...
@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig
...
@@ -19,11 +19,14 @@ from vllm.config.load import LoadConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.lora.model_manager
import
LoRAMapping
from
vllm.lora.model_manager
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.gpu_worker
import
Worker
from
vllm.v1.worker.gpu_worker
import
Worker
MODEL_PATH
=
"Qwen/Qwen3-0.6B"
MODEL_PATH
=
"Qwen/Qwen3-0.6B"
NUM_LORAS
=
16
NUM_LORAS
=
16
DEVICE_TYPE
=
current_platform
.
device_type
@
patch
.
dict
(
os
.
environ
,
{
"RANK"
:
"0"
})
@
patch
.
dict
(
os
.
environ
,
{
"RANK"
:
"0"
})
def
test_worker_apply_lora
(
qwen3_lora_files
):
def
test_worker_apply_lora
(
qwen3_lora_files
):
...
@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files):
...
@@ -61,7 +64,7 @@ def test_worker_apply_lora(qwen3_lora_files):
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_num_partial_prefills
=
32
,
max_num_partial_prefills
=
32
,
),
),
device_config
=
DeviceConfig
(
"cuda"
),
device_config
=
DeviceConfig
(
DEVICE_TYPE
),
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
16
,
block_size
=
16
,
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
...
...
tests/lora/utils.py
View file @
32e0c0bf
...
@@ -9,10 +9,13 @@ import torch
...
@@ -9,10 +9,13 @@ import torch
from
safetensors.torch
import
save_file
from
safetensors.torch
import
save_file
from
vllm.lora.lora_weights
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora_weights
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
class
DummyLoRAManager
:
class
DummyLoRAManager
:
def
__init__
(
self
,
device
:
torch
.
device
=
"cuda
:0"
):
def
__init__
(
self
,
device
:
torch
.
device
=
f
"
{
DEVICE_TYPE
}
:0"
):
super
().
__init__
()
super
().
__init__
()
self
.
_loras
:
dict
[
str
,
LoRALayerWeights
]
=
{}
self
.
_loras
:
dict
[
str
,
LoRALayerWeights
]
=
{}
self
.
_device
=
device
self
.
_device
=
device
...
@@ -57,8 +60,8 @@ class DummyLoRAManager:
...
@@ -57,8 +60,8 @@ class DummyLoRAManager:
module_name
,
module_name
,
rank
=
rank
,
rank
=
rank
,
lora_alpha
=
1
,
lora_alpha
=
1
,
lora_a
=
torch
.
rand
([
rank
,
input_dim
],
device
=
"cuda"
),
lora_a
=
torch
.
rand
([
rank
,
input_dim
],
device
=
DEVICE_TYPE
),
lora_b
=
torch
.
rand
([
output_dim
,
input_dim
],
device
=
"cuda"
),
lora_b
=
torch
.
rand
([
output_dim
,
input_dim
],
device
=
DEVICE_TYPE
),
embeddings_tensor
=
embeddings_tensor
,
embeddings_tensor
=
embeddings_tensor
,
)
)
self
.
set_module_lora
(
module_name
,
lora
)
self
.
set_module_lora
(
module_name
,
lora
)
...
...
tests/v1/attention/test_attention_backends.py
View file @
32e0c0bf
...
@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [
...
@@ -40,6 +40,8 @@ BACKENDS_TO_TEST = [
"FLEX_ATTENTION_SLOW"
,
"FLEX_ATTENTION_SLOW"
,
]
]
DEVICE_TYPE
=
current_platform
.
device_type
# Remove flashinfer from the list if it's not available
# Remove flashinfer from the list if it's not available
try
:
try
:
import
flashinfer
# noqa: F401
import
flashinfer
# noqa: F401
...
@@ -366,7 +368,7 @@ def _test_backend_correctness(
...
@@ -366,7 +368,7 @@ def _test_backend_correctness(
num_gpu_blocks
=
8192
,
num_gpu_blocks
=
8192
,
hf_config_override
=
hf_config_override
,
hf_config_override
=
hf_config_override
,
)
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
...
...
tests/v1/attention/test_chunked_local_attention.py
View file @
32e0c0bf
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
import
torch
import
torch
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
make_local_attention_virtual_batches
from
vllm.v1.attention.backends.utils
import
make_local_attention_virtual_batches
...
@@ -22,6 +23,8 @@ class LocalAttentionTestData:
...
@@ -22,6 +23,8 @@ class LocalAttentionTestData:
expected_local_block_table
:
list
[
list
[
int
]]
expected_local_block_table
:
list
[
list
[
int
]]
DEVICE_TYPE
=
current_platform
.
device_type
test_data_list
=
[
test_data_list
=
[
# Same as example in docstring of make_local_attention_virtual_batches
# Same as example in docstring of make_local_attention_virtual_batches
# except block table has 9 columns instead of 10
# except block table has 9 columns instead of 10
...
@@ -151,7 +154,7 @@ test_data_list = [
...
@@ -151,7 +154,7 @@ test_data_list = [
@
pytest
.
mark
.
parametrize
(
"test_data"
,
test_data_list
)
@
pytest
.
mark
.
parametrize
(
"test_data"
,
test_data_list
)
def
test_local_attention_virtual_batches
(
test_data
:
LocalAttentionTestData
):
def
test_local_attention_virtual_batches
(
test_data
:
LocalAttentionTestData
):
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
batch_spec
=
test_data
.
batch_spec
batch_spec
=
test_data
.
batch_spec
attn_chunk_size
=
test_data
.
attn_chunk_size
attn_chunk_size
=
test_data
.
attn_chunk_size
block_size
=
test_data
.
block_size
block_size
=
test_data
.
block_size
...
...
tests/v1/attention/test_mla_backends.py
View file @
32e0c0bf
...
@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [
...
@@ -42,6 +42,8 @@ BACKENDS_TO_TEST = [
AttentionBackendEnum
.
TRITON_MLA
,
AttentionBackendEnum
.
TRITON_MLA
,
]
]
DEVICE_TYPE
=
current_platform
.
device_type
# Remove sm100 backends from the list if not using sm100
# Remove sm100 backends from the list if not using sm100
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
get_device_properties
(
0
).
major
<
10
:
BACKENDS_TO_TEST
.
remove
(
AttentionBackendEnum
.
CUTLASS_MLA
)
BACKENDS_TO_TEST
.
remove
(
AttentionBackendEnum
.
CUTLASS_MLA
)
...
@@ -763,7 +765,7 @@ def test_backend_correctness(
...
@@ -763,7 +765,7 @@ def test_backend_correctness(
method
=
"ngram"
,
num_speculative_tokens
=
query_len
-
1
method
=
"ngram"
,
num_speculative_tokens
=
query_len
-
1
)
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
# 1. Setup
# 1. Setup
batch_size
=
batch_spec
.
batch_size
batch_size
=
batch_spec
.
batch_size
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
32e0c0bf
...
@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
...
@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens
=
[
256
]
*
2
,
query_lens
=
[
256
]
*
2
seq_lens
=
[
256
]
*
2
,
query_lens
=
[
256
]
*
2
)
)
DEVICE_TYPE
=
current_platform
.
device_type
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
...
@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
...
@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
batch_spec
=
SPARSE_BACKEND_BATCH_SPECS
[
batch_name
]
batch_spec
=
SPARSE_BACKEND_BATCH_SPECS
[
batch_name
]
use_fp8_ds_mla_quantization
=
kv_cache_dtype
==
"fp8_ds_mla"
use_fp8_ds_mla_quantization
=
kv_cache_dtype
==
"fp8_ds_mla"
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
# Model hyper-parameters (kept intentionally small for the unit test)
# Model hyper-parameters (kept intentionally small for the unit test)
...
@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
...
@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
def
test_triton_convert_req_index_to_global_index_decode_only
(
def
test_triton_convert_req_index_to_global_index_decode_only
(
block_size
,
num_topk_tokens
block_size
,
num_topk_tokens
):
):
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_tokens
=
8
num_tokens
=
8
num_requests
=
4
num_requests
=
4
max_blocks_per_req
=
10
max_blocks_per_req
=
10
...
@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
...
@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
reason
=
"FlashMLASparseBackend requires CUDA 9.0 or higher"
,
reason
=
"FlashMLASparseBackend requires CUDA 9.0 or higher"
,
)
)
def
test_triton_convert_req_index_to_global_index_with_prefill_workspace
(
block_size
):
def
test_triton_convert_req_index_to_global_index_with_prefill_workspace
(
block_size
):
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_requests
=
4
num_requests
=
4
max_blocks_per_req
=
8
max_blocks_per_req
=
8
num_topk_tokens
=
128
num_topk_tokens
=
128
...
@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
...
@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
def
test_triton_convert_returns_valid_counts
():
def
test_triton_convert_returns_valid_counts
():
"""Test that return_valid_counts correctly counts non-negative indices."""
"""Test that return_valid_counts correctly counts non-negative indices."""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
num_tokens
=
8
num_tokens
=
8
num_requests
=
2
num_requests
=
2
max_blocks_per_req
=
10
max_blocks_per_req
=
10
...
...
tests/v1/attention/test_trtllm_attention_integration.py
View file @
32e0c0bf
...
@@ -55,6 +55,7 @@ class MockAttentionLayer:
...
@@ -55,6 +55,7 @@ class MockAttentionLayer:
MODEL
=
"Qwen/Qwen2.5-0.5B"
MODEL
=
"Qwen/Qwen2.5-0.5B"
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
NUM_GPU_BLOCKS
=
8192
NUM_GPU_BLOCKS
=
8192
DEVICE_TYPE
=
current_platform
.
device_type
BATCH_SPECS
=
{
BATCH_SPECS
=
{
"decode_only"
:
BatchSpec
(
"decode_only"
:
BatchSpec
(
...
@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec):
...
@@ -172,7 +173,7 @@ def _run_trtllm_integration(batch_spec):
"""Run TRTLLM attention through the full FlashInfer pipeline
"""Run TRTLLM attention through the full FlashInfer pipeline
and compare against an SDPA reference."""
and compare against an SDPA reference."""
set_random_seed
(
42
)
set_random_seed
(
42
)
device
=
torch
.
device
(
"cuda
:0"
)
device
=
torch
.
device
(
f
"
{
DEVICE_TYPE
}
:0"
)
vllm_config
=
create_vllm_config
(
vllm_config
=
create_vllm_config
(
model_name
=
MODEL
,
model_name
=
MODEL
,
...
...
tests/v1/cudagraph/test_cudagraph_dispatch.py
View file @
32e0c0bf
...
@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
...
@@ -23,6 +23,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
DEVICE_TYPE
=
current_platform
.
device_type
# Helper MLP for testing
# Helper MLP for testing
class
SimpleMLP
(
nn
.
Module
):
class
SimpleMLP
(
nn
.
Module
):
...
@@ -269,9 +271,9 @@ class TestCudagraphDispatcher:
...
@@ -269,9 +271,9 @@ class TestCudagraphDispatcher:
class
TestCUDAGraphWrapper
:
class
TestCUDAGraphWrapper
:
def
setup_method
(
self
):
def
setup_method
(
self
):
self
.
vllm_config
=
_create_vllm_config
(
CompilationConfig
())
self
.
vllm_config
=
_create_vllm_config
(
CompilationConfig
())
self
.
model
=
SimpleMLP
().
to
(
"cuda"
)
self
.
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
self
.
persistent_input_buffer
=
torch
.
zeros
(
1
,
10
,
device
=
"cuda"
)
self
.
persistent_input_buffer
=
torch
.
zeros
(
1
,
10
,
device
=
DEVICE_TYPE
)
self
.
input_tensor
=
torch
.
randn
(
1
,
10
,
device
=
"cuda"
)
self
.
input_tensor
=
torch
.
randn
(
1
,
10
,
device
=
DEVICE_TYPE
)
def
test_capture_and_replay
(
self
):
def
test_capture_and_replay
(
self
):
wrapper
=
CUDAGraphWrapper
(
wrapper
=
CUDAGraphWrapper
(
...
@@ -428,10 +430,10 @@ class TestCudagraphIntegration:
...
@@ -428,10 +430,10 @@ class TestCudagraphIntegration:
@
create_new_process_for_each_test
(
"spawn"
)
@
create_new_process_for_each_test
(
"spawn"
)
def
test_capture_replay_bypass_logic
(
self
):
def
test_capture_replay_bypass_logic
(
self
):
model
=
SimpleMLP
().
to
(
"cuda"
)
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
max_bs
=
16
max_bs
=
16
persistent_input_buffer
=
torch
.
zeros
(
max_bs
,
10
,
device
=
"cuda"
)
persistent_input_buffer
=
torch
.
zeros
(
max_bs
,
10
,
device
=
DEVICE_TYPE
)
input_1
=
persistent_input_buffer
[:
1
]
input_1
=
persistent_input_buffer
[:
1
]
input_2
=
persistent_input_buffer
[:
2
]
input_2
=
persistent_input_buffer
[:
2
]
input_3
=
persistent_input_buffer
[:
3
]
input_3
=
persistent_input_buffer
[:
3
]
...
@@ -486,17 +488,17 @@ class TestCudagraphIntegration:
...
@@ -486,17 +488,17 @@ class TestCudagraphIntegration:
@
create_new_process_for_each_test
(
"spawn"
)
@
create_new_process_for_each_test
(
"spawn"
)
def
test_nested_wrappers
(
self
):
def
test_nested_wrappers
(
self
):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model
=
SimpleMLP
().
to
(
"cuda"
)
model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
full_wrapper
=
CUDAGraphWrapper
(
model
,
self
.
vllm_config
,
CUDAGraphMode
.
FULL
)
input_1
=
torch
.
randn
(
1
,
10
,
device
=
"cuda"
)
input_1
=
torch
.
randn
(
1
,
10
,
device
=
DEVICE_TYPE
)
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model
=
SimpleMLP
().
to
(
"cuda"
)
inner_model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
piecewise_wrapper
=
CUDAGraphWrapper
(
piecewise_wrapper
=
CUDAGraphWrapper
(
inner_model
,
self
.
vllm_config
,
CUDAGraphMode
.
PIECEWISE
inner_model
,
self
.
vllm_config
,
CUDAGraphMode
.
PIECEWISE
)
)
inner_model
.
forward
=
MagicMock
(
wraps
=
inner_model
.
forward
)
inner_model
.
forward
=
MagicMock
(
wraps
=
inner_model
.
forward
)
outer_model
=
SimpleMLP
().
to
(
"cuda"
)
outer_model
=
SimpleMLP
().
to
(
DEVICE_TYPE
)
# When outer model is called, it calls the piecewise_wrapper
# When outer model is called, it calls the piecewise_wrapper
outer_model
.
forward
=
MagicMock
(
outer_model
.
forward
=
MagicMock
(
wraps
=
outer_model
.
forward
,
side_effect
=
piecewise_wrapper
wraps
=
outer_model
.
forward
,
side_effect
=
piecewise_wrapper
...
...
tests/v1/determinism/test_rms_norm_batch_invariant.py
View file @
32e0c0bf
...
@@ -13,6 +13,9 @@ from utils import skip_unsupported
...
@@ -13,6 +13,9 @@ from utils import skip_unsupported
from
vllm.model_executor.layers.batch_invariant
import
rms_norm
as
triton_rms_norm
from
vllm.model_executor.layers.batch_invariant
import
rms_norm
as
triton_rms_norm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
DEVICE_TYPE
=
current_platform
.
device_type
@
skip_unsupported
@
skip_unsupported
...
@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
...
@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
equivalent results to the standard CUDA implementation across various
equivalent results to the standard CUDA implementation across various
configurations.
configurations.
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
# Create test input and weight
# Create test input and weight
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
...
@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
...
@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
inputs that are common in transformer models.
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
eps
=
1e-6
eps
=
1e-6
...
@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
...
@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
Ensures that both implementations handle edge cases like very small or large
Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
values without producing NaN or Inf.
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
float16
dtype
=
torch
.
float16
eps
=
1e-6
eps
=
1e-6
hidden_size
=
2048
hidden_size
=
2048
...
@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
...
@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
float32
# Use float32 for higher precision in formula check
dtype
=
torch
.
float32
# Use float32 for higher precision in formula check
eps
=
1e-6
eps
=
1e-6
hidden_size
=
1024
hidden_size
=
1024
...
@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
...
@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
correctly handles hidden sizes both smaller and larger than the block size.
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
eps
=
1e-6
eps
=
1e-6
batch_size
=
16
batch_size
=
16
...
@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
...
@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
Runs the same input through the kernel multiple times and verifies
Runs the same input through the kernel multiple times and verifies
identical outputs.
identical outputs.
"""
"""
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
eps
=
1e-6
eps
=
1e-6
hidden_size
=
4096
hidden_size
=
4096
...
@@ -283,7 +286,7 @@ if __name__ == "__main__":
...
@@ -283,7 +286,7 @@ if __name__ == "__main__":
# Run a quick smoke test
# Run a quick smoke test
print
(
"Running quick smoke test of RMS norm implementations..."
)
print
(
"Running quick smoke test of RMS norm implementations..."
)
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
DEVICE_TYPE
)
batch_size
=
8
batch_size
=
8
hidden_size
=
4096
hidden_size
=
4096
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
...
...
tests/v1/e2e/general/test_mamba_prefix_cache.py
View file @
32e0c0bf
...
@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
...
@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.model_executor.layers.mamba.mamba_utils
import
MambaStateCopyFunc
from
vllm.model_executor.layers.mamba.mamba_utils
import
MambaStateCopyFunc
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
...
@@ -48,6 +49,7 @@ num_accepted_tokens = 1
...
@@ -48,6 +49,7 @@ num_accepted_tokens = 1
prompt_token_ids
:
list
[
int
]
=
[]
prompt_token_ids
:
list
[
int
]
=
[]
MODEL
=
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
MODEL
=
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE
=
560
BLOCK_SIZE
=
560
DEVICE_TYPE
=
current_platform
.
device_type
NUM_HIDDEN_LAYERS
=
1
NUM_HIDDEN_LAYERS
=
1
cur_step_action_idx
=
0
cur_step_action_idx
=
0
cur_step_action
:
StepAction
|
None
=
None
cur_step_action
:
StepAction
|
None
=
None
...
@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
...
@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
return
SamplerOutput
(
return
SamplerOutput
(
sampled_token_ids
=
torch
.
tensor
(
sampled_token_ids
=
torch
.
tensor
(
[[
prompt_token_ids
[
first_token_id_index
]]],
[[
prompt_token_ids
[
first_token_id_index
]]],
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
logprobs_tensors
=
None
,
logprobs_tensors
=
None
,
...
@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
...
@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
sampled_token_ids
=
accepted_tokens
sampled_token_ids
=
accepted_tokens
return
SamplerOutput
(
return
SamplerOutput
(
sampled_token_ids
=
torch
.
tensor
(
sampled_token_ids
=
torch
.
tensor
(
[
sampled_token_ids
],
device
=
"cuda"
,
dtype
=
torch
.
int32
[
sampled_token_ids
],
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
),
),
logprobs_tensors
=
None
,
logprobs_tensors
=
None
,
)
)
...
@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
...
@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
-
1
-
1
+
num_accepted_tokens
+
num_accepted_tokens
],
],
device
=
"cuda"
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
)
)
valid_sampled_tokens_count
=
torch
.
tensor
(
valid_sampled_tokens_count
=
torch
.
tensor
(
[
num_accepted_tokens
],
device
=
"cuda"
,
dtype
=
torch
.
int32
[
num_accepted_tokens
],
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
return
torch
.
tensor
(
proposed_draft_token_ids
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
torch
.
tensor
(
proposed_draft_token_ids
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
int32
,
)
return
fake_propose_draft_token_ids_fn
return
fake_propose_draft_token_ids_fn
...
...
tests/v1/kv_offload/test_cpu_gpu.py
View file @
32e0c0bf
...
@@ -6,6 +6,7 @@ import time
...
@@ -6,6 +6,7 @@ import time
import
pytest
import
pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.spec
import
(
from
vllm.v1.kv_offload.spec
import
(
...
@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024]
...
@@ -21,7 +22,8 @@ GPU_PAGE_SIZES = [512, 1024]
BLOCK_SIZE_FACTORS
=
[
1
,
3
]
BLOCK_SIZE_FACTORS
=
[
1
,
3
]
NUM_TENSORS
=
[
4
]
NUM_TENSORS
=
[
4
]
SEEDS
=
[
0
]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
"cuda:0"
]
DEVICE_TYPE
=
current_platform
.
device_type
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:0"
]
NUM_MAPPINGS
=
[
3
]
NUM_MAPPINGS
=
[
3
]
...
@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3]
...
@@ -33,7 +35,7 @@ NUM_MAPPINGS = [3]
@
pytest
.
mark
.
parametrize
(
"num_cpu_blocks"
,
NUM_CPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_cpu_blocks"
,
NUM_CPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_tensors"
,
NUM_TENSORS
)
@
pytest
.
mark
.
parametrize
(
"num_tensors"
,
NUM_TENSORS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_transfer
(
def
test_transfer
(
default_vllm_config
,
default_vllm_config
,
...
...
tests/v1/logits_processors/test_correctness.py
View file @
32e0c0bf
...
@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
...
@@ -39,8 +39,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS
=
256
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
DEVICE_TYPE
=
current_platform
.
device_type
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
]
MAX_NUM_PROMPT_TOKENS
=
64
MAX_NUM_PROMPT_TOKENS
=
64
...
@@ -801,7 +802,7 @@ def _assert_valid(
...
@@ -801,7 +802,7 @@ def _assert_valid(
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"reqs_per_logitproc"
,
[
REQS_PER_LOGITPROC
])
@
pytest
.
mark
.
parametrize
(
"reqs_per_logitproc"
,
[
REQS_PER_LOGITPROC
])
@
pytest
.
mark
.
parametrize
(
"logitsprocs_under_test"
,
_get_test_cases
())
@
pytest
.
mark
.
parametrize
(
"logitsprocs_under_test"
,
_get_test_cases
())
def
test_logitsprocs
(
def
test_logitsprocs
(
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
32e0c0bf
...
@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
...
@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
from
vllm.v1.sample.sampler
import
Sampler
,
SamplerOutput
from
vllm.v1.sample.sampler
import
Sampler
,
SamplerOutput
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
DEVICE
=
current_platform
.
device_type
DEVICE
_TYPE
=
current_platform
.
device_type
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -57,7 +57,7 @@ def create_logits_tensor(
...
@@ -57,7 +57,7 @@ def create_logits_tensor(
will produce desired token ids on argmax"""
will produce desired token ids on argmax"""
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
logits
=
torch
.
full
((
num_total_tokens
,
vocab_size
),
-
100.0
,
device
=
DEVICE
)
logits
=
torch
.
full
((
num_total_tokens
,
vocab_size
),
-
100.0
,
device
=
DEVICE
_TYPE
)
start_loc
=
0
start_loc
=
0
for
tokens
in
token_ids
:
for
tokens
in
token_ids
:
for
j
,
token_id
in
enumerate
(
tokens
):
for
j
,
token_id
in
enumerate
(
tokens
):
...
@@ -99,9 +99,9 @@ def create_sampling_metadata(
...
@@ -99,9 +99,9 @@ def create_sampling_metadata(
assert
output_token_ids
assert
output_token_ids
assert
len
(
output_token_ids
)
>
0
assert
len
(
output_token_ids
)
>
0
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
device
=
DEVICE
)
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
device
=
DEVICE
_TYPE
)
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
device
=
DEVICE
)
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
device
=
DEVICE
_TYPE
)
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
device
=
DEVICE
)
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
device
=
DEVICE
_TYPE
)
else
:
else
:
no_penalties
=
True
no_penalties
=
True
frequency_penalties
=
torch
.
tensor
([])
frequency_penalties
=
torch
.
tensor
([])
...
@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
...
@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
n_rep
:
int
,
n_rep
:
int
,
):
):
num_tokens
=
batch_size
*
k
num_tokens
=
batch_size
*
k
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
,
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
target_logits
=
torch
.
rand_like
(
draft_probs
)
target_logits
=
torch
.
rand_like
(
draft_probs
)
bonus_token_ids
=
torch
.
randint
(
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
)
draft_token_ids
=
torch
.
randint
(
draft_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
DEVICE
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
)
seeded_mask
=
torch
.
rand
(
batch_size
,
dtype
=
torch
.
float32
)
<=
frac_seeded
seeded_mask
=
torch
.
rand
(
batch_size
,
dtype
=
torch
.
float32
)
<=
frac_seeded
...
@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
...
@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
results
=
[]
results
=
[]
for
_
in
range
(
n_rep
):
for
_
in
range
(
n_rep
):
seeded_seqs
=
{
seeded_seqs
=
{
i
:
torch
.
Generator
(
device
=
DEVICE
).
manual_seed
(
i
)
i
:
torch
.
Generator
(
device
=
DEVICE
_TYPE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
for
i
in
range
(
batch_size
)
if
seeded_mask
[
i
]
if
seeded_mask
[
i
]
}
}
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
seeded_seqs
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
seeded_seqs
)
)
...
@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
...
@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
much more than the distance improvement between the observed
much more than the distance improvement between the observed
distribution and the random distribution.
distribution and the random distribution.
"""
"""
torch
.
set_default_device
(
DEVICE
)
torch
.
set_default_device
(
DEVICE
_TYPE
)
vocab_size
=
10
vocab_size
=
10
k
=
2
k
=
2
num_reference_probs
=
100
num_reference_probs
=
100
...
@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
...
@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
rej_sample_probs
=
estimate_rejection_sampling_pdf
(
rej_sample_probs
=
estimate_rejection_sampling_pdf
(
draft_probs
,
target_logits
,
k
,
vocab_size
,
num_samples
draft_probs
,
target_logits
,
k
,
vocab_size
,
num_samples
)
)
rej_sample_probs
=
rej_sample_probs
.
to
(
DEVICE
)
rej_sample_probs
=
rej_sample_probs
.
to
(
DEVICE
_TYPE
)
# Average distance from reference probs.
# Average distance from reference probs.
reference_vs_rejsample_dist
=
(
reference_vs_rejsample_dist
=
(
...
@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
...
@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
draft_probs
=
draft_probs
.
view
(
num_tokens
,
vocab_size
)
draft_probs
=
draft_probs
.
view
(
num_tokens
,
vocab_size
)
# Bonus tokens not used but required.
# Bonus tokens not used but required.
bonus_token_ids
=
torch
.
zeros
((
1
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
).
repeat
(
bonus_token_ids
=
torch
.
zeros
((
1
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
_TYPE
).
repeat
(
num_samples
,
1
num_samples
,
1
)
)
temperature
=
torch
.
ones
(
num_samples
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
num_samples
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
all_greedy
=
False
,
temperature
=
temperature
)
)
...
@@ -600,7 +613,7 @@ def _test_masked_logits(
...
@@ -600,7 +613,7 @@ def _test_masked_logits(
# Create random draft probabilities.
# Create random draft probabilities.
draft_probs
=
torch
.
rand
(
draft_probs
=
torch
.
rand
(
(
num_tokens
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
DEVICE
(
num_tokens
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
...
@@ -610,7 +623,11 @@ def _test_masked_logits(
...
@@ -610,7 +623,11 @@ def _test_masked_logits(
draft_token_ids
=
draft_token_ids
.
tolist
()
draft_token_ids
=
draft_token_ids
.
tolist
()
# Bonus tokens not used but required
# Bonus tokens not used but required
bonus_token_ids
=
torch
.
zeros
((
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE
)
bonus_token_ids
=
torch
.
zeros
(
(
batch_size
,
1
),
dtype
=
torch
.
int64
,
device
=
DEVICE_TYPE
,
)
# Create spec decode metadata
# Create spec decode metadata
spec_decode_metadata
=
create_spec_decode_metadata
(
draft_token_ids
,
target_logits
)
spec_decode_metadata
=
create_spec_decode_metadata
(
draft_token_ids
,
target_logits
)
...
@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
...
@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices.
# Randomly create top-k indices.
top_k_indices
=
[
top_k_indices
=
[
torch
.
randperm
(
vocab_size
,
device
=
DEVICE
)[:
top_k
]
for
_
in
range
(
num_tokens
)
torch
.
randperm
(
vocab_size
,
device
=
DEVICE_TYPE
)[:
top_k
]
for
_
in
range
(
num_tokens
)
]
]
top_k_indices
=
torch
.
stack
(
top_k_indices
)
top_k_indices
=
torch
.
stack
(
top_k_indices
)
# Create logits with the uniform distribution.
# Create logits with the uniform distribution.
target_logits
=
torch
.
zeros
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
target_logits
=
torch
.
zeros
((
num_tokens
,
vocab_size
),
device
=
DEVICE
_TYPE
)
# Increment the logits for top-k indices, a little bit more than the other
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
# ones. If the masking is effective, the non-topk indices will never be
...
@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
...
@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
target_logits
[
i
,
top_k_indices
[
i
]]
+=
0.1
target_logits
[
i
,
top_k_indices
[
i
]]
+=
0.1
# Create sampling metadata
# Create sampling metadata
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
sampling_metadata
=
create_sampling_metadata
(
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
all_greedy
=
False
,
temperature
=
temperature
,
temperature
=
temperature
,
top_k
=
torch
.
tensor
([
top_k
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
int64
),
top_k
=
torch
.
tensor
([
top_k
]
*
batch_size
,
device
=
DEVICE
_TYPE
,
dtype
=
torch
.
int64
),
)
)
_test_masked_logits
(
_test_masked_logits
(
...
@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
...
@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
num_tokens
=
batch_size
*
num_draft_tokens
num_tokens
=
batch_size
*
num_draft_tokens
# Create logits with the uniform distribution.
# Create logits with the uniform distribution.
target_logits
=
torch
.
randn
((
num_tokens
,
vocab_size
),
device
=
DEVICE
)
target_logits
=
torch
.
randn
((
num_tokens
,
vocab_size
),
device
=
DEVICE
_TYPE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
rescaled_logits
=
target_logits
/
temperature
rescaled_logits
=
target_logits
/
temperature
logits_sort
,
logits_idx
=
rescaled_logits
.
sort
(
dim
=-
1
,
descending
=
False
)
logits_sort
,
logits_idx
=
rescaled_logits
.
sort
(
dim
=-
1
,
descending
=
False
)
...
@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
...
@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata
=
create_sampling_metadata
(
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
all_greedy
=
False
,
temperature
=
temperature
,
temperature
=
temperature
,
top_p
=
torch
.
tensor
([
top_p
]
*
batch_size
,
device
=
DEVICE
,
dtype
=
torch
.
float32
),
top_p
=
torch
.
tensor
(
[
top_p
]
*
batch_size
,
device
=
DEVICE_TYPE
,
dtype
=
torch
.
float32
,
),
)
)
_test_masked_logits
(
_test_masked_logits
(
...
@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
...
@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
all_greedy
=
True
,
all_greedy
=
True
,
output_token_ids
=
[[
2
],
[
3
],
[
4
]],
output_token_ids
=
[[
2
],
[
3
],
[
4
]],
spec_token_ids
=
spec_tokens
,
spec_token_ids
=
spec_tokens
,
prompt_token_ids
=
torch
.
tensor
([[
5
,
6
,
7
],
[
6
,
7
,
8
],
[
7
,
8
,
9
]],
device
=
DEVICE
),
prompt_token_ids
=
torch
.
tensor
(
[[
5
,
6
,
7
],
[
6
,
7
,
8
],
[
7
,
8
,
9
]],
device
=
DEVICE_TYPE
,
),
frequency_penalties
=
[
1.5
,
1.5
,
0.7
],
frequency_penalties
=
[
1.5
,
1.5
,
0.7
],
presence_penalties
=
[
0.0
]
*
num_requests
,
presence_penalties
=
[
0.0
]
*
num_requests
,
repetition_penalties
=
[
1.0
]
*
num_requests
,
repetition_penalties
=
[
1.0
]
*
num_requests
,
...
@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
...
@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
num_tokens
=
batch_size
*
max_spec_len
num_tokens
=
batch_size
*
max_spec_len
# Create random draft probabilities.
# Create random draft probabilities.
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
draft_probs
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE_TYPE
,
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
draft_probs
=
F
.
softmax
(
draft_probs
,
dim
=-
1
)
# Create random target probabilities.
# Create random target probabilities.
target_logits
=
torch
.
rand
(
target_logits
=
torch
.
rand
(
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
num_tokens
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
)
target_probs
=
F
.
softmax
(
target_logits
,
dim
=-
1
)
target_probs
=
F
.
softmax
(
target_logits
,
dim
=-
1
)
# Randomly sample draft token ids from draft probs
# Randomly sample draft token ids from draft probs
draft_token_ids
=
torch
.
multinomial
(
draft_probs
,
num_samples
=
1
).
to
(
torch
.
int32
)
draft_token_ids
=
torch
.
multinomial
(
draft_probs
,
num_samples
=
1
).
to
(
torch
.
int32
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
)
temperature
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
float32
,
device
=
DEVICE
_TYPE
)
generators
=
{
generators
=
{
i
:
torch
.
Generator
(
device
=
DEVICE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
i
:
torch
.
Generator
(
device
=
DEVICE
_TYPE
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
}
}
sampling_metadata
=
create_sampling_metadata
(
sampling_metadata
=
create_sampling_metadata
(
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
generators
all_greedy
=
False
,
temperature
=
temperature
,
generators
=
generators
...
@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
...
@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
None
if
no_draft_probs
else
draft_probs
,
None
if
no_draft_probs
else
draft_probs
,
target_probs
,
target_probs
,
sampling_metadata
,
sampling_metadata
,
device
=
DEVICE
,
device
=
DEVICE
_TYPE
,
)
)
recovered_token_ids
=
sample_recovered_tokens
(
recovered_token_ids
=
sample_recovered_tokens
(
max_spec_len
,
max_spec_len
,
...
@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
...
@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
None
if
no_draft_probs
else
draft_probs
,
None
if
no_draft_probs
else
draft_probs
,
target_probs
,
target_probs
,
sampling_metadata
,
sampling_metadata
,
device
=
DEVICE
,
device
=
DEVICE
_TYPE
,
)
)
assert
torch
.
equal
(
recovered_token_ids
,
ref_recovered_token_ids
)
assert
torch
.
equal
(
recovered_token_ids
,
ref_recovered_token_ids
)
tests/v1/sample/test_sampler.py
View file @
32e0c0bf
...
@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
...
@@ -17,8 +17,9 @@ PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS
=
256
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
DEVICE_TYPE
=
current_platform
.
device_type
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
DEVICES
=
[
f
"
{
DEVICE_TYPE
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
]
MAX_NUM_PROMPT_TOKENS
=
64
MAX_NUM_PROMPT_TOKENS
=
64
...
@@ -199,7 +200,7 @@ def _create_weighted_output_token_list(
...
@@ -199,7 +200,7 @@ def _create_weighted_output_token_list(
return
output_token_ids
,
sorted_token_ids_in_output
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
def
test_sampler_presence_penalty
(
def
test_sampler_presence_penalty
(
...
@@ -249,7 +250,7 @@ def test_sampler_presence_penalty(
...
@@ -249,7 +250,7 @@ def test_sampler_presence_penalty(
assert
penalized_token_id
not
in
output_token_ids
[
batch_idx
]
assert
penalized_token_id
not
in
output_token_ids
[
batch_idx
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"frequency_penalty"
,
[
-
2.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"frequency_penalty"
,
[
-
2.0
,
2.0
])
def
test_sampler_frequency_penalty
(
def
test_sampler_frequency_penalty
(
...
@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty(
...
@@ -305,7 +306,7 @@ def test_sampler_frequency_penalty(
assert
penalized_token_id
not
in
distinct_sorted_token_ids_in_output
assert
penalized_token_id
not
in
distinct_sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"repetition_penalty"
,
[
0.1
,
1.9
])
@
pytest
.
mark
.
parametrize
(
"repetition_penalty"
,
[
0.1
,
1.9
])
def
test_sampler_repetition_penalty
(
def
test_sampler_repetition_penalty
(
...
@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty(
...
@@ -363,7 +364,7 @@ def test_sampler_repetition_penalty(
)
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
def
test_sampler_allowed_token_ids
(
def
test_sampler_allowed_token_ids
(
...
@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids(
...
@@ -409,7 +410,7 @@ def test_sampler_allowed_token_ids(
assert
logits_for_req
[
token_id
]
!=
-
float
(
"inf"
)
assert
logits_for_req
[
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bad_words_lengths"
,
[(
1
,),
(
1
,
3
),
(
2
,
2
)])
@
pytest
.
mark
.
parametrize
(
"bad_words_lengths"
,
[(
1
,),
(
1
,
3
),
(
2
,
2
)])
def
test_sampler_bad_words
(
def
test_sampler_bad_words
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment