Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b668055a
Unverified
Commit
b668055a
authored
Aug 28, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 28, 2025
Browse files
[V0 Deprecation] Remove V0 Samplers test (#23862)
parent
d3d2aad5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
855 deletions
+0
-855
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+0
-769
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+0
-86
No files found.
tests/samplers/test_sampler.py
deleted
100644 → 0
View file @
d3d2aad5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
random
from
dataclasses
import
dataclass
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
import
pytest
import
torch
from
transformers
import
GenerationConfig
,
GenerationMixin
import
vllm.envs
as
envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
Counter
,
is_pin_memory_available
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
class
MockLogitsSampler
(
Sampler
):
def
__init__
(
self
,
fake_logits
:
torch
.
Tensor
):
super
().
__init__
()
self
.
fake_logits
=
fake_logits
def
forward
(
self
,
*
args
,
**
kwargs
):
return
super
().
forward
(
*
args
,
**
kwargs
)
def
_prepare_test
(
batch_size
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
1e-2
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
return
input_tensor
,
fake_logits
,
sampler
VOCAB_SIZE
=
32000
RANDOM_SEEDS
=
list
(
range
(
128
))
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
def
_do_sample
(
batch_size
:
int
,
input_tensor
:
torch
.
Tensor
,
sampler
:
MockLogitsSampler
,
sampling_params
:
SamplingParams
,
device
:
str
,
):
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_greedy
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
expected
[
i
].
item
()
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random_seed
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_all_random_seed_deterministic
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
first_sampler_output
==
second_sampler_output
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_min_tokens_penalty
(
seed
:
int
,
device
:
str
):
seq_id_counter
=
Counter
(
start
=
random
.
randint
(
0
,
100
))
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
*
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
max_tokens
=
9999
,
# keep higher than max of min_tokens
stop_token_ids
=
stop_token_ids
,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs
=
prompt_logprobs
,
)
sampling_params
.
all_stop_token_ids
.
add
(
eos_token_id
)
return
sampling_params
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
.
from_seqs
(
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
return
seq_data
def
generate_test_case
():
# generate multiple seq groups but limit total batch size
batch_size
=
random
.
randint
(
1
,
128
)
expected_penalization
=
[]
sequence_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
num_seqs
=
1
if
is_prompt
else
random
.
randint
(
1
,
batch_size
)
eos_token_id
=
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
min_tokens
=
random
.
randint
(
0
,
50
)
num_stop_tokens
=
random
.
randint
(
0
,
8
)
if
num_stop_tokens
>
0
:
stop_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
-
1
),
k
=
num_stop_tokens
)
else
:
stop_token_ids
=
None
sampling_params
=
create_sampling_params
(
min_tokens
=
min_tokens
,
eos_token_id
=
eos_token_id
,
stop_token_ids
=
stop_token_ids
)
seq_data
:
dict
[
int
,
SequenceData
]
=
{}
seq_group_penalization
:
list
[
bool
]
=
[]
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
seq_data
[
next
(
seq_id_counter
)]
=
create_sequence_data
(
num_input
=
num_input
,
num_generated
=
num_generated
)
seq_group_penalization
.
append
(
num_generated
<
min_tokens
)
expected_penalization
.
extend
(
seq_group_penalization
)
sequence_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
batch_size
}
"
,
is_prompt
=
is_prompt
,
seq_data
=
seq_data
,
sampling_params
=
sampling_params
,
block_tables
=
{},
))
batch_size
-=
num_seqs
return
{
"expected_penalization"
:
expected_penalization
,
"seq_group_metadata_list"
:
sequence_metadata_list
,
}
# define some explicit test cases for edge case behavior
prompt_without_penalization
=
{
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
),
block_tables
=
{},
),
]
}
prompt_with_penalization
=
{
"expected_penalization"
:
[
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
1
),
block_tables
=
{},
),
]
}
prompt_with_penalization_and_prompt_logprobs
=
{
"expected_penalization"
:
[
False
,
False
,
True
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
3
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
]
}
stop_penalizing_after_min_tokens
=
{
"expected_penalization"
:
[
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
},
sampling_params
=
create_sampling_params
(
1
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
42
,
99
,
42
,
0
]
# intentional duplication
prompt_combination
=
{
"expected_penalization"
:
[
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_input
=
2
),
},
sampling_params
=
create_sampling_params
(
1
,
prompt_logprobs
=
3
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_3"
,
is_prompt
=
True
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(),
},
sampling_params
=
create_sampling_params
(
0
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
)
]
}
stop_token_ids
=
[
1
,
999
,
37
,
37
]
# intentional duplication
decode_combination
=
{
"expected_penalization"
:
[
True
,
False
,
False
,
True
,
False
],
"seq_group_metadata_list"
:
[
SequenceGroupMetadata
(
request_id
=
"test_1"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
100
),
},
sampling_params
=
create_sampling_params
(
2
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
),
SequenceGroupMetadata
(
request_id
=
"test_2"
,
is_prompt
=
False
,
seq_data
=
{
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
20
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
1
),
next
(
seq_id_counter
):
create_sequence_data
(
num_generated
=
10
),
},
sampling_params
=
create_sampling_params
(
10
,
prompt_logprobs
=
5
,
stop_token_ids
=
stop_token_ids
),
block_tables
=
{},
),
]
}
if
seed
==
0
:
test_cases
=
[
prompt_without_penalization
,
prompt_with_penalization
,
prompt_with_penalization_and_prompt_logprobs
,
stop_penalizing_after_min_tokens
,
prompt_combination
,
decode_combination
,
]
else
:
test_cases
=
[
generate_test_case
()]
def
run_test_case
(
*
,
expected_penalization
:
list
[
bool
],
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]):
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
seq_lens
:
list
[
int
]
=
[]
sampling_params_per_row
:
list
[
SamplingParams
]
=
[]
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
num_rows
=
len
(
sgm
.
seq_data
)
if
sgm
.
is_prompt
:
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
seq_lens
.
append
(
prompt_len
)
assert
sgm
.
sampling_params
is
not
None
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows
=
prompt_len
batch_size
+=
num_rows
sampling_params_per_row
.
extend
(
itertools
.
repeat
(
sampling_params
,
num_rows
))
assert
len
(
expected_penalization
)
==
batch_size
,
\
(
"Invalid test case, expected_penalization does not match computed"
"batch size"
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
=
seq_lens
if
seq_lens
else
None
,
query_lens
=
seq_lens
if
seq_lens
else
[
1
]
*
batch_size
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
zip
(
expected_penalization
,
sampling_params_per_row
)):
tokens_to_check
=
sampling_params
.
all_stop_token_ids
if
should_penalize
:
for
token_id
in
tokens_to_check
:
assert
fake_logits
[
logits_idx
,
token_id
]
==
-
float
(
'inf'
),
f
"Expected token
{
token_id
}
for logits row
{
logits_idx
}
"
" to be penalized"
# no other tokens should be set to -inf
assert
torch
.
count_nonzero
(
fake_logits
[
logits_idx
,
:]
==
-
float
(
'inf'
))
==
len
(
tokens_to_check
),
f
"Expected only
{
len
(
tokens_to_check
)
}
to be penalized"
else
:
# no tokens should be set to -inf
assert
torch
.
count_nonzero
(
fake_logits
[
logits_idx
,
:]
==
-
float
(
'inf'
))
==
0
,
"No tokens should have been penalized"
for
test_case
in
test_cases
:
run_test_case
(
**
test_case
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_mixed
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
expected_tokens
:
list
[
Optional
[
list
[
int
]]]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
expected
:
Optional
[
list
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
2
)
if
sampling_type
==
0
:
sampling_params
=
SamplingParams
(
temperature
=
0
)
expected
=
[
int
(
torch
.
argmax
(
fake_logits
[
i
],
dim
=-
1
).
item
())]
elif
sampling_type
in
(
1
,
2
):
n
=
random
.
randint
(
1
,
10
)
sampling_params
=
SamplingParams
(
temperature
=
random
.
random
()
+
0.1
,
top_p
=
min
(
random
.
random
()
+
0.1
,
1
),
top_k
=
random
.
randint
(
0
,
10
),
n
=
n
,
presence_penalty
=
random
.
randint
(
0
,
1
),
)
if
sampling_type
==
2
:
sampling_params
.
seed
=
random
.
randint
(
0
,
10000
)
else
:
for
idx
in
range
(
n
):
fake_logits
[
i
,
i
+
idx
]
=
1e2
expected
=
list
(
range
(
i
,
i
+
n
))
expected_tokens
.
append
(
expected
)
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
dict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
(),
generators
=
generators
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
i
,
(
sequence_output
,
metadata
)
in
enumerate
(
zip
(
sampler_output
,
seq_group_metadata_list
)):
assert
metadata
.
sampling_params
is
not
None
if
(
metadata
.
sampling_params
.
seed
is
not
None
and
expected_tokens
[
i
]
is
None
):
# Record seeded random result to compare with results of
# second invocation
expected_tokens
[
i
]
=
[
nth_output
.
output_token
for
nth_output
in
sequence_output
.
samples
]
continue
expected_tokens_item
=
expected_tokens
[
i
]
assert
expected_tokens_item
is
not
None
for
n
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
metadata
.
sampling_params
is
not
None
if
(
metadata
.
sampling_params
.
temperature
==
0
or
metadata
.
sampling_params
.
seed
is
not
None
):
# Ensure exact matches for greedy or random with seed
assert
nth_output
.
output_token
==
expected_tokens_item
[
n
]
else
:
# For non-seeded random check that one of the high-logit
# tokens were chosen
assert
nth_output
.
output_token
in
expected_tokens_item
# Test batch
test_sampling
()
# Shuffle the batch and resample
target_index
=
list
(
range
(
batch_size
))
for
list_to_shuffle
in
(
target_index
,
seq_group_metadata_list
,
expected_tokens
,
seq_lens
):
random
.
Random
(
seed
).
shuffle
(
list_to_shuffle
)
target_index
=
torch
.
tensor
(
target_index
)
input_tensor
.
data
=
input_tensor
.
index_select
(
0
,
target_index
)
fake_logits
.
data
=
fake_logits
.
index_select
(
0
,
target_index
)
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling
()
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_top_k_top_p
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
top_k
=
random
.
randint
(
100
,
500
)
top_p
=
random
.
random
()
*
0.1
vocab_size
=
32000
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
device
=
device
,
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
normal
(
0
,
5
,
size
=
(
batch_size
,
vocab_size
),
device
=
input_tensor
.
device
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
generation_model
=
GenerationMixin
()
generation_config
=
GenerationConfig
(
top_k
=
top_k
,
top_p
=
top_p
,
do_sample
=
True
)
@
dataclass
class
MockConfig
:
is_encoder_decoder
:
bool
=
False
generation_model
.
config
=
MockConfig
()
# needed by the following method
generation_model
.
_prepare_special_tokens
(
generation_config
,
device
=
device
)
processors
=
generation_model
.
_get_logits_processor
(
generation_config
,
None
,
None
,
None
,
[],
device
=
device
)
assert
len
(
processors
)
==
2
# top_p and top_k
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
1
,
top_k
=
top_k
,
top_p
=
top_p
,
),
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
sample_probs
=
None
def
mock_sample
(
probs
,
*
args
,
**
kwargs
):
nonlocal
sample_probs
sample_probs
=
probs
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
],
None
)
# top-k and top-p is only calculated when flashinfer kernel is not available
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
),
\
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
None
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
assert
sample_probs
is
not
None
hf_probs
=
processors
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
torch
.
testing
.
assert_close
(
hf_probs
,
sample_probs
,
rtol
=
0.0
,
atol
=
1e-5
)
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_flashinfer_fallback
(
seed
:
int
,
device
:
str
):
if
not
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
pytest
.
skip
(
"Flashinfer sampler is disabled"
)
pytest
.
skip
(
"After FlashInfer 0.2.3, sampling will never fail"
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
def
failing_flashinfer_sampling
(
*
_args
,
**
_kwargs
):
return
None
,
torch
.
zeros
(
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
with
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
failing_flashinfer_sampling
):
fallback_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
sampler_output
==
fallback_sampler_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_repetition_penalty_mixed
(
device
:
str
):
vocab_size
=
8
def
test_sampling_params
(
sampling_params
:
list
[
SamplingParams
]):
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
list
[
int
]
=
[]
for
i
in
range
(
2
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
fake_logits
=
torch
.
full
((
2
,
vocab_size
),
1e-2
,
device
=
device
,
dtype
=
torch
.
float16
)
fake_logits
[:,
5
]
=
1.1e-2
fake_logits
[:,
1
]
=
1.2e-2
sampler
=
MockLogitsSampler
(
fake_logits
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
generated_tokens
=
[]
for
output
in
sampler_output
:
generated_tokens
.
append
(
output
.
samples
[
0
].
output_token
)
return
generated_tokens
# one configuration is greedy with repetition_penalty
sampling_params_rep
=
SamplingParams
(
temperature
=
0.0
,
repetition_penalty
=
2.0
,
)
# other configuration is sampling w/o repetition_penalty
sampling_params_sample
=
SamplingParams
(
temperature
=
1.0
,
top_k
=
1
,
seed
=
42
,
)
tokens1
=
test_sampling_params
(
[
sampling_params_rep
,
sampling_params_sample
])
tokens2
=
test_sampling_params
(
[
sampling_params_sample
,
sampling_params_rep
])
assert
tokens1
[
0
]
==
tokens2
[
1
]
assert
tokens1
[
1
]
==
tokens2
[
0
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_include_gpu_probs_tensor
(
device
:
str
):
set_random_seed
(
42
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampler
.
include_gpu_probs_tensor
=
True
sampler
.
should_modify_greedy_probs_inplace
=
False
sampling_params
=
SamplingParams
(
temperature
=
0
)
mock_inplace
=
Mock
()
with
patch
(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace"
,
mock_inplace
):
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
mock_inplace
.
assert_not_called
()
assert
sampler_output
.
sampled_token_probs
is
not
None
assert
sampler_output
.
logprobs
is
not
None
assert
sampler_output
.
sampled_token_ids
is
not
None
tests/samplers/test_seeded_generate.py
deleted
100644 → 0
View file @
d3d2aad5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import
copy
import
random
from
itertools
import
combinations
import
pytest
from
vllm
import
SamplingParams
from
vllm.model_executor.utils
import
set_random_seed
MODEL
=
"facebook/opt-125m"
RANDOM_SEEDS
=
list
(
range
(
5
))
@
pytest
.
fixture
def
vllm_model
(
vllm_runner
,
monkeypatch
):
# This file relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
with
vllm_runner
(
MODEL
,
dtype
=
"half"
)
as
vllm_model
:
yield
vllm_model
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
def
test_random_sample_with_seed
(
vllm_model
,
example_prompts
,
seed
:
int
,
)
->
None
:
set_random_seed
(
seed
)
sampling_params
=
SamplingParams
(
# Parameters to ensure sufficient randomness
temperature
=
3.0
,
top_p
=
min
(
random
.
random
()
+
0.3
,
1
),
top_k
=
random
.
randint
(
5
,
20
),
n
=
random
.
randint
(
1
,
10
),
presence_penalty
=
random
.
randint
(
0
,
1
),
max_tokens
=
8
,
ignore_eos
=
True
,
)
sampling_params_seed_1
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_1
.
seed
=
100
sampling_params_seed_2
=
copy
.
deepcopy
(
sampling_params
)
sampling_params_seed_2
.
seed
=
200
llm
=
vllm_model
.
llm
for
prompt
in
example_prompts
:
for
params
in
(
sampling_params
,
sampling_params_seed_1
,
sampling_params_seed_2
,
sampling_params
,
sampling_params_seed_1
,
sampling_params_seed_2
,
):
llm
.
_add_request
(
prompt
,
params
=
params
)
results
=
llm
.
_run_engine
(
use_tqdm
=
False
)
all_outputs
=
[[
out
.
token_ids
for
out
in
output
.
outputs
]
for
output
in
results
]
for
i
in
range
(
0
,
len
(
example_prompts
),
6
):
outputs
=
all_outputs
[
i
:
i
+
6
]
# verify all non-seeded requests differ
for
output_a
,
output_b
in
combinations
(
(
outputs
[
0
],
outputs
[
1
],
outputs
[
2
],
outputs
[
3
]),
2
,
):
assert
output_a
!=
output_b
# verify requests with the same seed match
assert
outputs
[
1
]
==
outputs
[
4
]
assert
outputs
[
2
]
==
outputs
[
5
]
# verify generations within the same parallel sampling group differ
for
output
in
outputs
:
for
sub_output_a
,
sub_output_b
in
combinations
(
output
,
2
):
assert
sub_output_a
!=
sub_output_b
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment