norm / vllm · Commits

Commit 9f88db35 (unverified)
Support top-k sampling (#94)

Authored May 10, 2023 by Woosuk Kwon; committed by GitHub on May 10, 2023.
Parent: ae356774
Showing 3 changed files with 78 additions and 43 deletions.
cacheflow/model_executor/layers/sampler.py   +37  -16
cacheflow/sampling_params.py                 +38  -25
simple_server.py                              +3   -2
cacheflow/model_executor/layers/sampler.py
@@ -46,12 +46,13 @@ class Sampler(nn.Module):
         # Compute the log probabilities (before applying top-p).
         logprobs = torch.log(probs)

-        # Apply top-p truncation.
-        top_ps = _get_top_ps(input_metadata)
-        assert len(top_ps) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps):
+        # Apply top-p and top-k truncation.
+        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
             p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            probs = _apply_top_p(probs, p)
+            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
+            probs = _apply_top_p_top_k(probs, p, k)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
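Note on the gate above: top_k = -1 is the sentinel for "top-k disabled", so the sort-based truncation runs only when at least one sequence in the batch requests top_p < 1.0 or top_k != -1; a batch of all-default requests skips the truncation kernel entirely.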
@@ -94,31 +95,51 @@ def _get_temperatures(
     return temperatures


-def _get_top_ps(
+def _get_top_p_top_k(
     input_metadata: InputMetadata,
-) -> List[float]:
+    vocab_size: int,
+) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
+    top_ks: List[int] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
+        top_p = sampling_params.top_p
+        # k should not be greater than the vocab size.
+        top_k = min(sampling_params.top_k, vocab_size)
+        # k=-1 means no truncation.
+        top_k = vocab_size if top_k == -1 else top_k
         if i < input_metadata.num_prompts:
             # A prompt input.
-            top_ps.append(sampling_params.top_p)
+            top_ps.append(top_p)
+            top_ks.append(top_k)
         else:
             # A generation token.
-            top_ps += [sampling_params.top_p] * len(seq_ids)
-    return top_ps
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+    return top_ps, top_ks


-def _apply_top_p(
+def _apply_top_p_top_k(
     probs: torch.Tensor,
     p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
     # TODO(woosuk): Optimize.
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

+    # Apply top-p.
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[mask] = 0.0
+    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+
+    # Apply top-k.
+    # Create a mask for the top-k elements.
+    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
+    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    probs_sort[top_k_mask] = 0.0
+
+    # Re-sort the probabilities.
     probs = torch.gather(probs_sort, dim=-1,
                          index=torch.argsort(probs_idx, dim=-1))
     return probs
@@ -160,7 +181,7 @@ def _sample_from_prompt(
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample n tokens for the prompt.
         n = sampling_params.n
         next_token_ids = torch.multinomial(
@@ -218,7 +239,7 @@ def _sample_from_generation_tokens(
         next_token_ids = [next_token_id.item()]
         parent_seq_ids = seq_ids
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample 1 token for each sequence in the group.
         next_token_ids = torch.multinomial(
             probs, num_samples=1, replacement=True)
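To see why _apply_top_p_top_k can apply both masks to the sorted tensor, here is a minimal standalone sketch (not part of the commit; the toy numbers are illustrative): after sorting in descending order, the tokens outside the nucleus are exactly those whose preceding cumulative mass exceeds p, and the tokens outside the top-k are exactly those at sorted positions >= k.

    import torch

    # One sequence over a vocab of 5; toy values for illustration.
    probs = torch.tensor([[0.1, 0.5, 0.2, 0.15, 0.05]])
    p = torch.tensor([0.8])   # nucleus mass
    k = torch.tensor([3])     # keep at most 3 tokens

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Top-p: zero a token once the mass *before* it already exceeds p.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > p.unsqueeze(dim=1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

    # Top-k: zero sorted positions at index >= k.
    arange = torch.arange(probs.shape[-1]).expand(probs.shape[0], -1)
    probs_sort[arange >= k.unsqueeze(dim=1)] = 0.0

    # Undo the sort so values line up with token ids again.
    restored = torch.gather(probs_sort, dim=-1,
                            index=torch.argsort(probs_idx, dim=-1))
    print(restored)  # nonzero mass only on token ids 1, 2, 3

Note that torch.multinomial does not require normalized inputs, so the mass removed by the top-k mask does not need to be renormalized before sampling.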
cacheflow/sampling_params.py
@@ -8,69 +8,82 @@ class SamplingParams:
         n: int,
         temperature: float,
         top_p: float,
+        top_k: int,
         use_beam_search: bool,
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
     ) -> None:
         if n < 1:
-            raise ValueError(f'n must be at least 1, got {n}.')
+            raise ValueError(f"n must be at least 1, got {n}.")
         if temperature < 0.0:
             raise ValueError(
-                f'temperature must be non-negative, got {temperature}.')
+                f"temperature must be non-negative, got {temperature}.")
         if not 0.0 < top_p <= 1.0:
-            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
+        if top_k < -1 or top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                             f"got {top_k}.")
         if max_num_steps < 1:
             raise ValueError(
-                f'max_num_steps must be at least 1, got {max_num_steps}.')
+                f"max_num_steps must be at least 1, got {max_num_steps}.")
         if num_logprobs < 0:
             raise ValueError(
-                f'num_logprobs must be non-negative, got {num_logprobs}.')
+                f"num_logprobs must be non-negative, got {num_logprobs}.")

         if use_beam_search:
             if n == 1:
                 raise ValueError(
-                    'n must be greater than 1 when using beam search.')
+                    "n must be greater than 1 when using beam search.")
             if temperature > 0.0:
                 raise ValueError(
-                    'temperature must be 0 when using beam search.')
+                    "temperature must be 0 when using beam search.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using beam search.')
+                    "top_p must be 1 when using beam search.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using beam search.")
         elif temperature == 0.0:
             # Zero temperature means greedy sampling.
             if n > 1:
                 raise ValueError(
-                    'n must be 1 when using greedy sampling.')
+                    "n must be 1 when using greedy sampling.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using greedy sampling.')
+                    "top_p must be 1 when using greedy sampling.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using greedy sampling.")

         self.n = n
         self.temperature = temperature
         self.top_p = top_p
+        self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs

     def __repr__(self) -> str:
-        return (f'SamplingParams(n={self.n}, '
-                f'temperature={self.temperature}, '
-                f'top_p={self.top_p}, '
-                f'use_beam_search={self.use_beam_search}, '
-                f'stop_token_ids={self.stop_token_ids}, '
-                f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs})')
+        return (f"SamplingParams(n={self.n}, "
+                f"temperature={self.temperature}, "
+                f"top_p={self.top_p}, "
+                f"top_k={self.top_k},"
+                f"use_beam_search={self.use_beam_search}, "
+                f"stop_token_ids={self.stop_token_ids}, "
+                f"max_num_steps={self.max_num_steps}, "
+                f"num_logprobs={self.num_logprobs})")

     @classmethod
-    def from_dict(cls, d: Dict) -> 'SamplingParams':
+    def from_dict(cls, d: Dict) -> "SamplingParams":
         return cls(
-            n=d.get('n', 1),
-            temperature=d.get('temperature', 1.0),
-            top_p=d.get('top_p', 1.0),
-            use_beam_search=d.get('use_beam_search', False),
-            stop_token_ids=set(d.get('stop_token_ids', set())),
-            max_num_steps=d.get('max_num_steps', 16),
-            num_logprobs=d.get('num_logprobs', 0),
+            n=d.get("n", 1),
+            temperature=d.get("temperature", 1.0),
+            top_p=d.get("top_p", 1.0),
+            top_k=d.get("top_k", -1),
+            use_beam_search=d.get("use_beam_search", False),
+            stop_token_ids=set(d.get("stop_token_ids", set())),
+            max_num_steps=d.get("max_num_steps", 16),
+            num_logprobs=d.get("num_logprobs", 0),
         )
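A brief usage sketch of the patched class (not part of the commit; it assumes the module is importable as cacheflow.sampling_params, the file path in this commit): top_k defaults to -1, the "disabled" sentinel, and the new checks reject k = 0 as well as any explicit k under greedy or beam-search decoding.

    from cacheflow.sampling_params import SamplingParams

    # Random sampling with top-k truncation.
    params = SamplingParams.from_dict({"temperature": 0.8, "top_k": 5})
    print(params)  # repr now includes top_k=5

    # k=0 would keep no tokens, so it is rejected.
    try:
        SamplingParams.from_dict({"top_k": 0})
    except ValueError as err:
        print(err)  # top_k must be -1 (disable), or at least 1, got 0.

    # Greedy sampling (temperature=0) requires top_k to stay at -1.
    try:
        SamplingParams.from_dict({"temperature": 0.0, "top_k": 5})
    except ValueError as err:
        print(err)  # top_k must be -1 when using greedy sampling.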
simple_server.py
@@ -11,8 +11,9 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("What is the meaning of life?", {"n": 3, "temperature": 0.8, "top_p": 0.99}),
-        ("It is only with the heart that one can see rightly", {"n": 4, "use_beam_search": True, "temperature": 0.0}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True:
         if test_inputs: