Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
d3a5bd9f
Unverified
Commit
d3a5bd9f
authored
Oct 16, 2023
by
Woosuk Kwon
Committed by
GitHub
Oct 16, 2023
Browse files
Fix sampler test (#1379)
parent
e8ef4c08
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+4
-4
No files found.
tests/samplers/test_sampler.py
View file @
d3a5bd9f
# pylint: disable=protected-access
# pylint: disable=protected-access
import
pytest
import
random
import
random
from
typing
import
Tuple
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
...
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
...
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
input_metadata
=
input_metadata
)
input_metadata
=
input_metadata
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
:
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
expected
[
i
].
item
()
assert
nth_output
.
output_token
==
expected
[
i
].
item
()
...
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
...
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
hidden_states
=
input_tensor
,
hidden_states
=
input_tensor
,
input_metadata
=
input_metadata
)
input_metadata
=
input_metadata
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
:
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
==
i
assert
nth_output
.
output_token
==
i
...
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
...
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
if
seq_group_metadata_list
[
i
].
sampling_params
.
use_beam_search
:
if
seq_group_metadata_list
[
i
].
sampling_params
.
use_beam_search
:
continue
continue
for
nth_output
in
sequence_output
:
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
in
expected_tokens
assert
nth_output
.
output_token
in
expected_tokens
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment