Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0faab90e
Unverified
Commit
0faab90e
authored
Sep 20, 2024
by
youkaichao
Committed by
GitHub
Sep 20, 2024
Browse files
[beam search] add output for manually checking the correctness (#8684)
parent
0455c46e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
3 deletions
+10
-3
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+10
-3
No files found.
tests/samplers/test_beam_search.py
View file @
0faab90e
...
...
@@ -11,7 +11,7 @@ import pytest
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS
=
[
128
]
BEAM_WIDTHS
=
[
4
]
MODELS
=
[
"
facebook/opt-125m
"
]
MODELS
=
[
"
TinyLlama/TinyLlama-1.1B-Chat-v1.0
"
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
...
...
@@ -37,8 +37,15 @@ def test_beam_search_single_input(
beam_width
,
max_tokens
)
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
_
=
hf_outputs
[
i
]
vllm_output_ids
,
_
=
vllm_outputs
[
i
]
hf_output_ids
,
hf_output_texts
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_texts
=
vllm_outputs
[
i
]
for
i
,
(
hf_text
,
vllm_text
)
in
enumerate
(
zip
(
hf_output_texts
,
vllm_output_texts
)):
print
(
f
">>>
{
i
}
-th hf output:"
)
print
(
hf_text
)
print
(
f
">>>
{
i
}
-th vllm output:"
)
print
(
vllm_text
)
assert
len
(
hf_output_ids
)
==
len
(
vllm_output_ids
)
for
j
in
range
(
len
(
hf_output_ids
)):
assert
hf_output_ids
[
j
]
==
vllm_output_ids
[
j
],
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment