vllm · Commit 5f09cbdb (unverified)

Fix broken sampler tests (#1896)

Authored Dec 02, 2023 by Woosuk Kwon; committed via GitHub on Dec 02, 2023
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Parent: 4cefa9b4
Changes: 2 changed files with 41 additions and 21 deletions

  tests/samplers/test_sampler.py   +37  -20
  vllm/worker/model_runner.py       +4   -1
tests/samplers/test_sampler.py
@@ -8,7 +8,7 @@ import torch
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner


 class MockLogitsSampler(Sampler):
@@ -27,7 +27,7 @@ class MockLogitsSampler(Sampler):

 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +37,8 @@ def _prepare_test(
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner


 RANDOM_SEEDS = list(range(128))
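Worth pausing on the hunk above: the fixture now hands back a bare ModelRunner built from three None configs instead of a Worker. Judging from the model_runner.py hunk at the end of this commit, those positional arguments correspond to model_config, parallel_config, and scheduler_config, and a None model_config is only tolerated because of the sliding-window guard added there. A minimal standalone sketch of that constructor shape (MiniModelRunner is hypothetical, not vLLM API):

```python
class MiniModelRunner:
    """Hypothetical miniature of the constructor shape the tests rely on."""

    def __init__(self, model_config, parallel_config, scheduler_config):
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Mirrors the guard added in vllm/worker/model_runner.py below:
        # model_config may be None when constructed from unit tests.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)


runner = MiniModelRunner(None, None, None)
assert runner.sliding_window is None
```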
@@ -49,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -61,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
                 sampling_params=SamplingParams(temperature=0, ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
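The assertion loop at the end of this hunk checks greedy decoding against torch.argmax(fake_logits, dim=-1). A tiny self-contained illustration of that expectation, with made-up values:

```python
import torch

# With temperature=0 (greedy), the sampler must return the argmax token
# of each row, which is exactly what `expected` captures above.
fake_logits = torch.tensor([[0.1, 0.9, 0.2],
                            [0.7, 0.3, 0.5]])
expected = torch.argmax(fake_logits, dim=-1)
assert expected.tolist() == [1, 0]
```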
@@ -76,12 +79,14 @@ def test_sampler_all_greedy(seed: int):
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     for i in range(batch_size):
         fake_logits[i, i] = 1e2

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
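The random test rigs row i of the logits with a single huge value at column i, so after softmax essentially all probability mass sits on token i and even random sampling must return it. A quick standalone check of that reasoning (the small background fill here is illustrative, not necessarily the exact value used in the test):

```python
import torch

logits = torch.full((1, 32000), 1e-2)  # small uniform background
logits[0, 7] = 1e2                     # one dominant logit, as in the test
probs = torch.softmax(logits, dim=-1)
assert probs[0, 7].item() > 0.9999     # sampling is effectively deterministic
```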
@@ -94,11 +99,13 @@ def test_sampler_all_random(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -108,9 +115,10 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -124,11 +132,13 @@ def test_sampler_all_beam(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -139,10 +149,12 @@ def test_sampler_all_beam(seed: int):
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -172,11 +184,13 @@ def test_sampler_mixed(seed: int):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
@@ -188,7 +202,7 @@ def test_sampler_mixed(seed: int):
 def test_sampler_logits_processors(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
@@ -198,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
         return logits

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -208,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
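The comment in the test_sampler_logits_processors hunk above says pick_ith gives an infinite score to the i-th token, where i is the length of the input sequence; that is why the final loop can assert nth_output.output_token == idx. A standalone sketch of a processor with that behavior (a reconstruction, not necessarily the exact body in the test file):

```python
from typing import List

import torch


def pick_ith(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    # Force the token whose id equals the current sequence length;
    # float("inf") beats any finite logit, so argmax (and greedy
    # sampling) must land on it.
    logits[len(token_ids)] = float("inf")
    return logits


row = pick_ith([101, 102, 103], torch.zeros(10))
assert torch.argmax(row).item() == 3
```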
vllm/worker/model_runner.py
@@ -25,7 +25,10 @@ class ModelRunner:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
-        self.sliding_window = model_config.get_sliding_window()
+        # model_config can be None in tests/samplers/test_sampler.py.
+        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+        self.sliding_window = (model_config.get_sliding_window()
+                               if model_config is not None else None)

         self.model = None
         self.block_size = None  # Set after initial profiling.

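This guard is the crux of the fix: before it, the tests' ModelRunner(None, None, None) crashed during construction, because calling a method on a None config raises AttributeError. A standalone demonstration of the failure mode the conditional expression avoids:

```python
model_config = None
try:
    sliding_window = model_config.get_sliding_window()  # old, unguarded line
except AttributeError:
    sliding_window = None  # what the new guard yields directly
assert sliding_window is None
```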