Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e5150f2c
Unverified
Commit
e5150f2c
authored
Jun 19, 2024
by
Thomas Parnell
Committed by
GitHub
Jun 19, 2024
Browse files
[Bugfix] Added test for sampling repetition penalty bug. (#5659)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
59a1eb59
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
0 deletions
+69
-0
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+69
-0
No files found.
tests/samplers/test_sampler.py
View file @
e5150f2c
...
@@ -631,3 +631,72 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -631,3 +631,72 @@ def test_sampler_top_k_top_p(seed: int, device: str):
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Check that sampled tokens follow the request, not its batch position.

    A two-request batch is sampled twice: once with a greedy request that
    uses ``repetition_penalty`` placed first and a seeded sampling request
    placed second, and once with the two swapped. The token produced for
    each configuration must be the same in both orderings.
    """
    vocab_size = 8

    def _sample_batch(params_per_request: List[SamplingParams]):
        """Run one sampler pass over a two-request batch.

        Returns a list holding the first sample's output token for each
        entry of the sampler output.
        """
        # Both requests share the same prompt tokens; only the sampling
        # parameters differ between them.
        groups: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=params_per_request[i],
                block_tables={0: [1]},
            ) for i in range(2)
        ]
        seq_lens: List[int] = [
            group.seq_data[0].get_len() for group in groups
        ]

        sampling_metadata = SamplingMetadata.prepare(
            groups,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
        )

        # Nearly flat logits with two slightly raised entries
        # (index 1 highest, index 5 second highest).
        fake_logits = torch.full(
            (2, vocab_size),
            1e-2,
            device=device,
            dtype=torch.float16,
        )
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(
            logits=fake_logits,
            sampling_metadata=sampling_metadata,
        )

        return [entry.samples[0].output_token for entry in sampler_output]

    # one configuration is greedy with repetition_penalty
    greedy_with_penalty = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    seeded_sampling = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    penalty_first = _sample_batch([greedy_with_penalty, seeded_sampling])
    penalty_last = _sample_batch([seeded_sampling, greedy_with_penalty])

    # Swapping the batch order must swap the results and nothing else.
    assert penalty_first[0] == penalty_last[1]
    assert penalty_first[1] == penalty_last[0]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment