Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
425040d4
Unverified
Commit
425040d4
authored
Jun 28, 2023
by
Lily Liu
Committed by
GitHub
Jun 28, 2023
Browse files
remove floats == 0 comparison (#285)
parent
4338cc47
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
9 deletions
+11
-9
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+6
-5
vllm/sampling_params.py
vllm/sampling_params.py
+5
-4
No files found.
vllm/model_executor/layers/sampler.py
View file @
425040d4
...
@@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
...
@@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceOutputs
from
vllm.sequence
import
SequenceOutputs
_SAMPLING_EPS
=
1e-5
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
"""Samples the next tokens from the model's outputs.
"""Samples the next tokens from the model's outputs.
...
@@ -74,7 +75,7 @@ class Sampler(nn.Module):
...
@@ -74,7 +75,7 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
# Apply top-p and top-k truncation.
top_ps
,
top_ks
=
_get_top_p_top_k
(
input_metadata
,
self
.
vocab_size
)
top_ps
,
top_ks
=
_get_top_p_top_k
(
input_metadata
,
self
.
vocab_size
)
assert
len
(
top_ps
)
==
len
(
top_ks
)
==
probs
.
shape
[
0
]
assert
len
(
top_ps
)
==
len
(
top_ks
)
==
probs
.
shape
[
0
]
if
any
(
p
<
1.0
for
p
in
top_ps
)
or
any
(
k
!=
self
.
vocab_size
for
k
in
top_ks
):
if
any
(
p
<
1.0
-
_SAMPLING_EPS
for
p
in
top_ps
)
or
any
(
k
!=
self
.
vocab_size
for
k
in
top_ks
):
probs
=
_apply_top_p_top_k
(
probs
,
top_ps
,
top_ks
)
probs
=
_apply_top_p_top_k
(
probs
,
top_ps
,
top_ks
)
# Sample the next tokens.
# Sample the next tokens.
...
@@ -152,7 +153,7 @@ def _apply_penalties(
...
@@ -152,7 +153,7 @@ def _apply_penalties(
continue
continue
p
=
presence_penalties
[
i
]
p
=
presence_penalties
[
i
]
f
=
frequency_penalties
[
i
]
f
=
frequency_penalties
[
i
]
if
p
==
0.0
and
f
==
0.0
:
if
p
<
_SAMPLING_EPS
and
f
<
_SAMPLING_EPS
:
continue
continue
indices
.
append
(
i
)
indices
.
append
(
i
)
...
@@ -190,7 +191,7 @@ def _get_temperatures(
...
@@ -190,7 +191,7 @@ def _get_temperatures(
for
i
,
seq_group
in
enumerate
(
input_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
input_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
seq_ids
,
sampling_params
=
seq_group
temperature
=
sampling_params
.
temperature
temperature
=
sampling_params
.
temperature
if
temperature
==
0.0
:
if
temperature
<
_SAMPLING_EPS
:
# NOTE: Zero temperature means deterministic sampling
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
# Set the temperature to 1 to avoid division by zero.
...
@@ -286,7 +287,7 @@ def _sample_from_prompt(
...
@@ -286,7 +287,7 @@ def _sample_from_prompt(
beam_width
=
sampling_params
.
best_of
beam_width
=
sampling_params
.
best_of
_
,
next_token_ids
=
torch
.
topk
(
prob
,
beam_width
)
_
,
next_token_ids
=
torch
.
topk
(
prob
,
beam_width
)
next_token_ids
=
next_token_ids
.
tolist
()
next_token_ids
=
next_token_ids
.
tolist
()
elif
sampling_params
.
temperature
==
0.0
:
elif
sampling_params
.
temperature
<
_SAMPLING_EPS
:
# Greedy sampling.
# Greedy sampling.
assert
sampling_params
.
best_of
==
1
assert
sampling_params
.
best_of
==
1
next_token_id
=
torch
.
argmax
(
prob
)
next_token_id
=
torch
.
argmax
(
prob
)
...
@@ -343,7 +344,7 @@ def _sample_from_generation_tokens(
...
@@ -343,7 +344,7 @@ def _sample_from_generation_tokens(
parent_seq_ids
=
[
beam_outputs
[
seq_id
][
0
]
for
seq_id
in
seq_ids
]
parent_seq_ids
=
[
beam_outputs
[
seq_id
][
0
]
for
seq_id
in
seq_ids
]
next_token_ids
=
[
beam_outputs
[
seq_id
][
1
]
for
seq_id
in
seq_ids
]
next_token_ids
=
[
beam_outputs
[
seq_id
][
1
]
for
seq_id
in
seq_ids
]
elif
sampling_params
.
temperature
==
0.0
:
elif
sampling_params
.
temperature
<
_SAMPLING_EPS
:
# Greedy sampling.
# Greedy sampling.
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
next_token_id
=
torch
.
argmax
(
probs
,
dim
=-
1
)
next_token_id
=
torch
.
argmax
(
probs
,
dim
=-
1
)
...
...
vllm/sampling_params.py
View file @
425040d4
"""Sampling parameters for text generation."""
"""Sampling parameters for text generation."""
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
_SAMPLING_EPS
=
1e-5
class
SamplingParams
:
class
SamplingParams
:
"""Sampling parameters for text generation.
"""Sampling parameters for text generation.
...
@@ -71,7 +72,7 @@ class SamplingParams:
...
@@ -71,7 +72,7 @@ class SamplingParams:
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
self
.
_verity_beam_search
()
self
.
_verity_beam_search
()
elif
self
.
temperature
==
0.0
:
elif
self
.
temperature
<
_SAMPLING_EPS
:
# Zero temperature means greedy sampling.
# Zero temperature means greedy sampling.
self
.
_verify_greedy_sampling
()
self
.
_verify_greedy_sampling
()
...
@@ -106,9 +107,9 @@ class SamplingParams:
...
@@ -106,9 +107,9 @@ class SamplingParams:
if
self
.
best_of
==
1
:
if
self
.
best_of
==
1
:
raise
ValueError
(
"best_of must be greater than 1 when using beam "
raise
ValueError
(
"best_of must be greater than 1 when using beam "
f
"search. Got
{
self
.
best_of
}
."
)
f
"search. Got
{
self
.
best_of
}
."
)
if
self
.
temperature
>
0.0
:
if
self
.
temperature
>
_SAMPLING_EPS
:
raise
ValueError
(
"temperature must be 0 when using beam search."
)
raise
ValueError
(
"temperature must be 0 when using beam search."
)
if
self
.
top_p
<
1.0
:
if
self
.
top_p
<
1.0
-
_SAMPLING_EPS
:
raise
ValueError
(
"top_p must be 1 when using beam search."
)
raise
ValueError
(
"top_p must be 1 when using beam search."
)
if
self
.
top_k
!=
-
1
:
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using beam search."
)
raise
ValueError
(
"top_k must be -1 when using beam search."
)
...
@@ -117,7 +118,7 @@ class SamplingParams:
...
@@ -117,7 +118,7 @@ class SamplingParams:
if
self
.
best_of
>
1
:
if
self
.
best_of
>
1
:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
f
"Got
{
self
.
best_of
}
."
)
if
self
.
top_p
<
1.0
:
if
self
.
top_p
<
1.0
-
_SAMPLING_EPS
:
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
raise
ValueError
(
"top_p must be 1 when using greedy sampling."
)
if
self
.
top_k
!=
-
1
:
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment