Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
5f09cbdb
Unverified
Commit
5f09cbdb
authored
Dec 02, 2023
by
Woosuk Kwon
Committed by
GitHub
Dec 02, 2023
Browse files
Fix broken sampler tests (#1896)
Co-authored-by:
Antoni Baum
<
antoni.baum@protonmail.com
>
parent
4cefa9b4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
21 deletions
+41
-21
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+37
-20
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-1
No files found.
tests/samplers/test_sampler.py
View file @
5f09cbdb
...
...
@@ -8,7 +8,7 @@ import torch
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.
work
er
import
Work
er
from
vllm.worker.
model_runn
er
import
ModelRunn
er
class
MockLogitsSampler
(
Sampler
):
...
...
@@ -27,7 +27,7 @@ class MockLogitsSampler(Sampler):
def
_prepare_test
(
batch_size
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
,
Work
er
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
,
ModelRunn
er
]:
vocab_size
=
32000
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
device
=
"cuda"
,
...
...
@@ -37,9 +37,8 @@ def _prepare_test(
device
=
input_tensor
.
device
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
32000
,
fake_logits
)
worker
=
Worker
(
None
,
None
,
None
)
worker
.
block_size
=
16
return
input_tensor
,
fake_logits
,
sampler
,
worker
model_runner
=
ModelRunner
(
None
,
None
,
None
)
return
input_tensor
,
fake_logits
,
sampler
,
model_runner
RANDOM_SEEDS
=
list
(
range
(
128
))
...
...
@@ -49,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
def
test_sampler_all_greedy
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
,
worker
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -61,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
sampling_params
=
SamplingParams
(
temperature
=
0
,
),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input
_metadata
=
input
_metadata
)
sampling
_metadata
=
sampling
_metadata
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
...
...
@@ -76,12 +79,14 @@ def test_sampler_all_greedy(seed: int):
def
test_sampler_all_random
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
,
worker
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -94,11 +99,13 @@ def test_sampler_all_random(seed: int):
),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input
_metadata
=
input
_metadata
)
sampling
_metadata
=
sampling
_metadata
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
...
...
@@ -108,9 +115,10 @@ def test_sampler_all_random(seed: int):
def
test_sampler_all_beam
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
_
,
sampler
,
work
er
=
_prepare_test
(
batch_size
)
input_tensor
,
_
,
sampler
,
model_runn
er
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -124,11 +132,13 @@ def test_sampler_all_beam(seed: int):
),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input
_metadata
=
input
_metadata
)
sampling
_metadata
=
sampling
_metadata
)
# No assertion here: it is unclear how to determine whether the outputs
# are as expected. In other words, this only tests that the sampler
# runs without raising an exception.
...
...
@@ -139,10 +149,12 @@ def test_sampler_all_beam(seed: int):
def
test_sampler_mixed
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
,
worker
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
=
[]
expected_tokens
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
n
=
1
sampling_type
=
random
.
randint
(
0
,
2
)
...
...
@@ -172,11 +184,13 @@ def test_sampler_mixed(seed: int):
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input
_metadata
=
input
_metadata
)
sampling
_metadata
=
sampling
_metadata
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
if
seq_group_metadata_list
[
i
].
sampling_params
.
use_beam_search
:
continue
...
...
@@ -188,7 +202,7 @@ def test_sampler_mixed(seed: int):
def
test_sampler_logits_processors
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
_
,
sampler
,
work
er
=
_prepare_test
(
batch_size
)
input_tensor
,
_
,
sampler
,
model_runn
er
=
_prepare_test
(
batch_size
)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
...
...
@@ -198,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
return
logits
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
...
...
@@ -208,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input
_metadata
=
input
_metadata
)
sampling
_metadata
=
sampling
_metadata
)
for
_
,
sequence_output
in
enumerate
(
sampler_output
):
for
idx
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
nth_output
.
output_token
==
idx
vllm/worker/model_runner.py
View file @
5f09cbdb
...
...
@@ -25,7 +25,10 @@ class ModelRunner:
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
sliding_window
=
model_config
.
get_sliding_window
()
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self
.
sliding_window
=
(
model_config
.
get_sliding_window
()
if
model_config
is
not
None
else
None
)
self
.
model
=
None
self
.
block_size
=
None
# Set after initial profiling.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment