norm / vllm · Commit e87557b0 (unverified)
Authored Nov 18, 2023 by Roy; committed by GitHub on Nov 17, 2023

Support Min P Sampler (#1642)
Parent: dcc543a2
Showing 2 changed files with 40 additions and 4 deletions (+40 −4):

    vllm/model_executor/layers/sampler.py   +31 −4
    vllm/sampling_params.py                  +9 −0
vllm/model_executor/layers/sampler.py
@@ -71,13 +71,18 @@ class Sampler(nn.Module):
         logits.div_(t.unsqueeze(dim=1))
 
         # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(input_metadata,
+                                                        self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)
 
+        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
+        if do_min_p:
+            logits = _apply_min_p(logits, min_ps)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
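In the updated forward pass, min-p pruning runs after temperature scaling and top-p/top-k truncation but before the final float32 softmax, and it is skipped when every request leaves min_p at 0.0. A minimal standalone sketch of that flow, assuming _SAMPLING_EPS is a small guard value such as 1e-5 (the function name and epsilon here are illustrative, not part of the commit):

import torch

_SAMPLING_EPS = 1e-5  # assumed small epsilon, mirroring the guard in the diff

def demo_sampling_flow(logits: torch.Tensor, temperature: float,
                       min_p: float) -> torch.Tensor:
    # 1. Temperature scaling (logits.div_(t) in the diff).
    logits = logits / temperature
    # 2. Top-p / top-k truncation would run here.
    # 3. Min-p pruning, only when min_p is meaningfully above zero.
    if min_p > _SAMPLING_EPS:
        probs = torch.softmax(logits, dim=-1)
        threshold = min_p * probs.max(dim=-1, keepdim=True).values
        logits = logits.masked_fill(probs < threshold, -float("inf"))
    # 4. Final float32 softmax used for sampling.
    return torch.softmax(logits, dim=-1, dtype=torch.float)

Gating on the epsilon presumably keeps the extra softmax off the hot path for requests that never set min_p.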
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures
 
 
-def _get_top_p_top_k(
+def _get_top_p_top_k_min_p(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    min_ps: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
+        min_p = sampling_params.min_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
+            min_ps += [min_p] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        min_ps += [min_p] * len(seq_ids)
+    return top_ps, top_ks, min_ps
 
 
 def _apply_top_p_top_k(
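The min_ps expansion mirrors top_ps and top_ks: each prompt run contributes prompt_len - 1 next-token positions and every sequence in the group contributes one more, so the list ends up with one entry per logit row. A rough sketch of that bookkeeping, assuming the surrounding loop guards the prompt branch by group index (the names and toy input below are made up for illustration):

from typing import List, Tuple

def demo_expand_min_ps(seq_groups: List[Tuple[List[int], float]],
                       prompt_lens: List[int],
                       num_prompts: int) -> List[float]:
    # Illustrative only: one min_p entry per logit row,
    # prompt positions first, then one per sequence.
    min_ps: List[float] = []
    for i, (seq_ids, min_p) in enumerate(seq_groups):
        if i < num_prompts:
            # A prompt of length L contributes L - 1 next-token positions.
            min_ps += [min_p] * (prompt_lens[i] - 1)
        min_ps += [min_p] * len(seq_ids)
    return min_ps

# One prompt of length 5 with a single sequence and min_p=0.1:
# 4 prompt positions + 1 generation position = 5 entries.
assert demo_expand_min_ps([([0], 0.1)], [5], 1) == [0.1] * 5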
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
     return logits
 
 
+def _apply_min_p(
+    logits: torch.Tensor,
+    min_ps: List[float],
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
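To make the new filter concrete, here is a small worked example with made-up logits: with min_p=0.1, a token survives only if its probability is at least 10% of the most likely token's probability.

import torch

# Toy batch: one sequence, vocab of 4 tokens (values chosen for illustration).
logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
probs = torch.softmax(logits, dim=-1)   # ~[0.66, 0.24, 0.09, 0.004]
min_p = torch.tensor([0.1])

# Same steps as _apply_min_p: threshold = min_p * max prob (~0.066 here).
scaled_min_p = min_p.unsqueeze(dim=1) * probs.max(dim=-1, keepdim=True).values
filtered = logits.masked_fill(probs < scaled_min_p, -float("inf"))

print(filtered)  # tensor([[2., 1., 0., -inf]]): only the rarest token is masked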
vllm/sampling_params.py
@@ -52,6 +52,9 @@ class SamplingParams:
             to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -94,6 +97,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: int = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ class SamplingParams:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ class SamplingParams:
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError("min_p must be in [0, 1], got "
+                             f"{self.min_p}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ class SamplingParams:
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
+                f"min_p={self.min_p}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
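With the parameter plumbed through SamplingParams, a request opts in per call. A hedged usage sketch of the offline API (the model name and the concrete values are placeholders, not part of the commit):

from vllm import LLM, SamplingParams

# min_p=0.05 keeps only tokens whose probability is at least 5% of the
# top token's probability; temperature/top_p/top_k behave as before.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, top_p=1.0, top_k=-1, min_p=0.05)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)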