Commit 947b7941 (unverified)
Authored Sep 22, 2023 by Zhuohan Li; committed by GitHub on Sep 22, 2023
Parent: 8d926e91

[Sampler] Vectorized sampling (simplified) (#1048)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Showing 3 changed files with 481 additions and 180 deletions (+481, -180).
tests/samplers/test_sampler.py           +184  -0
vllm/model_executor/layers/sampler.py    +281  -180
vllm/sampling_params.py                  +16   -0
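A note on the shape of the change before the per-file diffs: instead of looping over sequence groups and sampling each one separately, the new sampler buckets sequence groups by SamplingType (GREEDY / RANDOM / BEAM) and issues one batched torch call per bucket. The following is a minimal, self-contained sketch of that categorize-then-batch pattern, not the vLLM code itself; the bucketing and the batched argmax/multinomial calls mirror the diff below, while the function name sample_batched and the toy shapes are illustrative assumptions.

from enum import IntEnum
from typing import Dict, List

import torch


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


def sample_batched(probs: torch.Tensor,
                   types: List[SamplingType]) -> List[int]:
    """Sample one token per row of `probs`, batching rows by sampling type."""
    # Bucket row indices by sampling type.
    buckets: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType}
    for row, t in enumerate(types):
        buckets[t].append(row)

    next_tokens = [0] * probs.size(0)
    for t, rows in buckets.items():
        if not rows:
            continue
        bucket_probs = probs[rows]  # one contiguous batch per sampling type
        if t == SamplingType.GREEDY:
            ids = torch.argmax(bucket_probs, dim=-1)
        elif t == SamplingType.RANDOM:
            ids = torch.multinomial(bucket_probs, num_samples=1).squeeze(-1)
        else:
            # Beam search needs per-group bookkeeping; fall back to top-1 here.
            ids = torch.argmax(bucket_probs, dim=-1)
        for row, token_id in zip(rows, ids.tolist()):
            next_tokens[row] = token_id
    return next_tokens


if __name__ == "__main__":
    probs = torch.softmax(torch.randn(4, 32), dim=-1)
    types = [SamplingType.GREEDY, SamplingType.RANDOM,
             SamplingType.GREEDY, SamplingType.RANDOM]
    print(sample_batched(probs, types))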
tests/samplers/test_sampler.py (new file, mode 100644)
import pytest
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output:
            assert nth_output.output_token in expected_tokens
vllm/model_executor/layers/sampler.py
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
from
typing
import
Dict
,
List
,
Tuple
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.parallel_utils.tensor_parallel
import
(
from
vllm.model_executor.parallel_utils.tensor_parallel
import
(
gather_from_tensor_model_parallel_region
)
gather_from_tensor_model_parallel_region
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceOutputs
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceOutputs
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
...
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
hidden_states
=
_prune_hidden_states
(
hidden_states
,
input_metadata
)
hidden_states
=
_prune_hidden_states
(
hidden_states
,
input_metadata
)
# Get the logits for the next tokens.
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
logits
=
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
,
if
embedding_bias
is
not
None
:
self
.
vocab_size
)
logits
+=
embedding_bias
logits
=
gather_from_tensor_model_parallel_region
(
logits
)
# Remove paddings in vocab (if any).
logits
=
logits
[:,
:
self
.
vocab_size
]
# Apply presence and frequency penalties.
# Apply presence and frequency penalties.
output_tokens
=
_get_output_tokens
(
input_metadata
)
output_tokens
=
_get_output_tokens
(
input_metadata
)
...
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
...
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
assert
len
(
presence_penalties
)
==
logits
.
shape
[
0
]
assert
len
(
presence_penalties
)
==
logits
.
shape
[
0
]
assert
len
(
frequency_penalties
)
==
logits
.
shape
[
0
]
assert
len
(
frequency_penalties
)
==
logits
.
shape
[
0
]
logits
=
_apply_penalties
(
logits
,
output_tokens
,
presence_penalties
,
logits
=
_apply_penalties
(
logits
,
output_tokens
,
presence_penalties
,
frequency_penalties
,
self
.
vocab_size
)
frequency_penalties
)
# Apply temperature scaling.
# Apply temperature scaling.
temperatures
=
_get_temperatures
(
input_metadata
)
temperatures
=
_get_temperatures
(
input_metadata
)
...
@@ -90,19 +85,47 @@ class Sampler(nn.Module):
...
@@ -90,19 +85,47 @@ class Sampler(nn.Module):
return
_sample
(
probs
,
logprobs
,
input_metadata
)
return
_sample
(
probs
,
logprobs
,
input_metadata
)
def
_get_logits
(
hidden_states
:
torch
.
Tensor
,
embedding
:
torch
.
Tensor
,
embedding_bias
:
Optional
[
torch
.
Tensor
],
vocab_size
:
int
)
->
torch
.
Tensor
:
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
if
embedding_bias
is
not
None
:
logits
+=
embedding_bias
logits
=
gather_from_tensor_model_parallel_region
(
logits
)
# Remove paddings in vocab (if any).
logits
=
logits
[:,
:
vocab_size
]
return
logits
def
_prune_hidden_states
(
def
_prune_hidden_states
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
last_token_indices
=
{
t
:
[]
for
t
in
SamplingType
}
start_idx
=
0
start_idx
=
0
last_token_indicies
:
List
[
int
]
=
[]
for
i
,
seq_group
in
enumerate
(
input_metadata
.
seq_groups
):
for
prompt_len
in
input_metadata
.
prompt_lens
:
seq_ids
,
sampling_params
=
seq_group
last_token_indicies
.
append
(
start_idx
+
prompt_len
-
1
)
sampling_type
=
sampling_params
.
sampling_type
if
i
<
input_metadata
.
num_prompts
:
assert
len
(
seq_ids
)
==
1
,
"Prompt input should have only one seq."
prompt_len
=
input_metadata
.
prompt_lens
[
i
]
last_token_indices
[
sampling_type
].
append
(
start_idx
+
prompt_len
-
1
)
start_idx
+=
prompt_len
start_idx
+=
prompt_len
last_token_indicies
.
extend
(
else
:
range
(
start_idx
,
start_idx
+
input_metadata
.
num_generation_tokens
))
num_seqs
=
len
(
seq_ids
)
return
hidden_states
.
index_select
(
last_token_indices
[
sampling_type
].
extend
(
0
,
torch
.
tensor
(
last_token_indicies
,
device
=
hidden_states
.
device
))
range
(
start_idx
,
start_idx
+
num_seqs
))
start_idx
+=
num_seqs
all_last_token_indices
=
[]
for
sampling_type
in
SamplingType
:
all_last_token_indices
.
extend
(
last_token_indices
[
sampling_type
])
all_last_token_indices
=
torch
.
tensor
(
all_last_token_indices
,
dtype
=
torch
.
long
,
device
=
hidden_states
.
device
)
return
hidden_states
.
index_select
(
0
,
all_last_token_indices
)
def
_get_penalties
(
def
_get_penalties
(
...
@@ -149,11 +172,8 @@ def _apply_penalties(
...
@@ -149,11 +172,8 @@ def _apply_penalties(
output_tokens
:
List
[
List
[
int
]],
output_tokens
:
List
[
List
[
int
]],
presence_penalties
:
List
[
float
],
presence_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
vocab_size
:
int
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_seqs
=
logits
.
shape
[
0
]
num_seqs
,
vocab_size
=
logits
.
shape
# Collect the indices of sequences that have non-zero penalties.
indices
=
[]
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
if
not
output_tokens
[
i
]:
if
not
output_tokens
[
i
]:
continue
continue
...
@@ -161,33 +181,40 @@ def _apply_penalties(
...
@@ -161,33 +181,40 @@ def _apply_penalties(
f
=
frequency_penalties
[
i
]
f
=
frequency_penalties
[
i
]
if
abs
(
p
)
<
_SAMPLING_EPS
and
abs
(
f
)
<
_SAMPLING_EPS
:
if
abs
(
p
)
<
_SAMPLING_EPS
and
abs
(
f
)
<
_SAMPLING_EPS
:
continue
continue
indices
.
append
(
i
)
break
else
:
# Return early if all sequences have zero penalties.
# Return early if all sequences have zero penalties.
if
not
indices
:
return
logits
return
logits
bin_counts
=
[]
max_output_len
=
max
(
len
(
tokens
)
for
tokens
in
output_tokens
)
for
i
in
indices
:
padded_output_tokens
=
[
bin_counts
.
append
(
np
.
bincount
(
output_tokens
[
i
],
minlength
=
vocab_size
))
tokens
+
[
vocab_size
]
*
(
max_output_len
-
len
(
tokens
))
bin_counts
=
np
.
stack
(
bin_counts
,
axis
=
0
)
for
tokens
in
output_tokens
bin_counts
=
torch
.
from_numpy
(
bin_counts
).
to
(
dtype
=
logits
.
dtype
,
]
output_tokens_tensor
=
torch
.
tensor
(
padded_output_tokens
,
dtype
=
torch
.
long
,
device
=
logits
.
device
)
# Compute the bin counts for the output tokens.
# vocab_size + 1 for padding.
bin_counts
=
torch
.
zeros
((
num_seqs
,
vocab_size
+
1
),
dtype
=
torch
.
long
,
device
=
logits
.
device
)
device
=
logits
.
device
)
bin_counts
.
scatter_add_
(
1
,
output_tokens_tensor
,
torch
.
ones_like
(
output_tokens_tensor
))
bin_counts
=
bin_counts
[:,
:
vocab_size
]
# Remove the padding bin.
frequency_penalties
=
[
frequency_penalties
[
i
]
for
i
in
indices
]
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
logits
.
dtype
,
dtype
=
logits
.
dtype
,
device
=
logits
.
device
)
device
=
logits
.
device
)
presence_penalties
=
[
presence_penalties
[
i
]
for
i
in
indices
]
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
logits
.
dtype
,
dtype
=
logits
.
dtype
,
device
=
logits
.
device
)
device
=
logits
.
device
)
# We follow the definition in OpenAI API.
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits
[
indices
]
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
bin_counts
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
bin_counts
presence_mask
=
(
bin_counts
>
0.0
).
to
(
dtype
=
logits
.
dtype
)
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
(
bin_counts
>
0
)
logits
[
indices
]
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
presence_mask
return
logits
return
logits
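An aside on the penalty hunk above: the new path replaces the per-sequence np.bincount loop with one scatter_add_ over output tokens padded to a common length, using vocab_size itself as the padding id so all the padding lands in an extra bin that is sliced off afterwards. A small standalone illustration of that trick with toy sizes (not the vLLM code):

import torch

vocab_size = 6
output_tokens = [[1, 1, 4], [2], []]  # per-sequence generated tokens
num_seqs = len(output_tokens)

# Pad every row to the same length with `vocab_size`, which acts as a
# throwaway bin index.
max_len = max(len(t) for t in output_tokens)
padded = [t + [vocab_size] * (max_len - len(t)) for t in output_tokens]
tokens = torch.tensor(padded, dtype=torch.long)

# One extra column collects all the padding counts.
bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

print(bin_counts)
# tensor([[0, 2, 0, 0, 1, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0]])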
...
@@ -268,95 +295,154 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
     num_logprobs: Optional[int],
-) -> Dict[int, float]:
+) -> List[Dict[int, float]]:
+    num_seqs = logprobs.size(0)
     if num_logprobs is None or num_logprobs == 0:
-        return {}
+        return [{} for _ in range(num_seqs)]

-    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
-    if num_logprobs == 1:
-        topk_logprobs = [topk_logprobs.item()]
-        topk_ids = [topk_ids.item()]
-    else:
-        topk_logprobs = topk_logprobs.tolist()
-        topk_ids = topk_ids.tolist()
-
-    token_to_logprob: Dict[int, float] = {}
-    for token_id, logprob in zip(topk_ids, topk_logprobs):
-        token_to_logprob[token_id] = logprob
-    return token_to_logprob
+    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+    all_topk_logprobs = all_topk_logprobs.cpu()
+    all_topk_ids = all_topk_ids.cpu()
+    all_token_to_logprob = []
+    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
+        token_to_logprob: Dict[int, float] = {}
+        for token_id, logprob in zip(topk_ids, topk_logprobs):
+            token_to_logprob[token_id.item()] = logprob.item()
+        all_token_to_logprob.append(token_to_logprob)
+    return all_token_to_logprob
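Similarly, _get_topk_logprobs above now runs a single top-k over the whole [num_seqs, vocab_size] tensor and then unpacks one dict per sequence, rather than calling torch.topk once per row. A toy version of the same pattern (illustrative only; the helper name topk_logprob_dicts is an assumption):

from typing import Dict, List

import torch


def topk_logprob_dicts(logprobs: torch.Tensor,
                       k: int) -> List[Dict[int, float]]:
    # One batched top-k over all rows, then one small Python loop to unpack.
    top_vals, top_ids = torch.topk(logprobs, k, dim=-1)
    top_vals, top_ids = top_vals.cpu(), top_ids.cpu()
    return [
        {int(i): float(v) for i, v in zip(row_ids, row_vals)}
        for row_ids, row_vals in zip(top_ids, top_vals)
    ]


logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)
print(topk_logprob_dicts(logprobs, k=2))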
+def _build_sequence_outputs(
+    parent_ids: List[int],
+    next_token_ids: List[int],
+    selected_token_logprobs: torch.Tensor,
+    parent_seq_ids: List[int],
+    parent_logprobs: torch.Tensor,
+    num_output_logprobs: Optional[int],
+) -> List[SequenceOutputs]:
+    # Get top-k log probabilities for the next tokens.
+    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
+
+    seq_outputs: List[SequenceOutputs] = []
+    for parent_id, next_token_id, token_logprob in zip(
+            parent_ids, next_token_ids, selected_token_logprobs):
+        output_logprobs = next_logprobs[parent_id].copy()
+        output_logprobs[next_token_id] = token_logprob
+        seq_outputs.append(
+            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
+                            output_logprobs))
+    return seq_outputs
+
+
-def _sample_from_prompt(
-    prob: torch.Tensor,
-    sampling_params: SamplingParams,
-) -> List[int]:
-    if sampling_params.use_beam_search:
-        # Beam search.
-        beam_width = sampling_params.best_of
-        # Sample 2 * beam_width candidates to make sure that with high
-        # probability we can get `beam_width` candidates in addition to
-        # the finished sequences for the next iteration. See
-        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-        # for details. See also HF reference:
-        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-        _, next_token_ids = torch.topk(prob, 2 * beam_width)
-        next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert sampling_params.best_of == 1
-        next_token_id = torch.argmax(prob)
-        next_token_ids = [next_token_id.item()]
-    else:
-        # Random sampling.
-        # Sample `best_of` tokens for the prompt.
-        num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
-        next_token_ids = next_token_ids.tolist()
-    return next_token_ids
+def _greedy_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    samples = torch.argmax(logprobs, dim=-1).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids, _ = seq_group
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx].item()]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results
+
+
+def _random_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    probs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # Find the maximum best_of value of the prompt phase requests.
+    max_best_of = 1
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        if is_prompt:
+            seq_ids, sampling_params = seq_group
+            max_best_of = max(max_best_of, sampling_params.best_of)
+    random_samples = torch.multinomial(probs,
+                                       num_samples=max_best_of,
+                                       replacement=True).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * sampling_params.best_of
+            next_token_ids = random_samples[
+                sample_idx, :sampling_params.best_of].tolist()
+        else:
+            # Generation phase.
+            parent_ids = list(range(num_parent_seqs))
+            next_token_ids = random_samples[sample_idx:sample_idx +
+                                            num_parent_seqs, 0].tolist()
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == probs.size(0)
+    return results
+
+
+def _beam_search_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    seq_data: Dict[int, SequenceData],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # We sample 2 * beam_width candidates to make sure that with high
+    # probability we can get `beam_width` candidates in addition to
+    # the finished sequences for the next iteration. See
+    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+    # for details. See also HF reference:
+    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+    #
+    # Note: Beam search is not vectorized, so its speed can be slower than
+    # other sampling methods.
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        beam_width = sampling_params.best_of
+        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * (2 * beam_width)
+            _, next_token_ids = torch.topk(seq_group_logprobs[0],
+                                           2 * beam_width)
+            next_token_ids = next_token_ids.tolist()
+        else:
+            # Generation phase.
+            cumulative_logprobs = [
+                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
+            ]
+            cumulative_logprobs = torch.tensor(
+                cumulative_logprobs,
+                dtype=torch.float,
+                device=seq_group_logprobs.device)
+            seq_group_logprobs = (seq_group_logprobs +
+                                  cumulative_logprobs.unsqueeze(dim=1))
+            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
+                                     2 * beam_width)
+            topk_ids = topk_ids.tolist()
+            vocab_size = seq_group_logprobs.size(-1)
+            parent_ids = [i // vocab_size for i in topk_ids]
+            next_token_ids = [i % vocab_size for i in topk_ids]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results


-def _sample_from_generation_tokens(
-    seq_ids: List[int],
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    seq_logprobs: List[float],
-    sampling_params: SamplingParams,
-) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.best_of can be greater than
-    # len(seq_ids) because some sequences in the group might have
-    # been already terminated.
-    if sampling_params.use_beam_search:
-        # Beam search.
-        # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs,
-                                    dtype=torch.float,
-                                    device=logprobs.device)
-        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
-
-        vocab_size = logprobs.size(-1)
-        beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
-        topk_ids = topk_ids.tolist()
-        seq_idx = [i // vocab_size for i in topk_ids]
-        parent_seq_ids = [seq_ids[i] for i in seq_idx]
-        next_token_ids = [i % vocab_size for i in topk_ids]
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert len(seq_ids) == 1
-        next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [int(next_token_id.item())]
-        parent_seq_ids = seq_ids
-    else:
-        # Random sampling.
-        # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples=1,
-                                           replacement=True)
-        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
-        parent_seq_ids = seq_ids
-    return parent_seq_ids, next_token_ids
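Both the removed generation-phase beam path and the new _beam_search_sample pick candidates by flattening a [num_parents, vocab_size] log-probability matrix, taking the top 2 * beam_width entries, and recovering the parent row as index // vocab_size and the token as index % vocab_size. A tiny numeric illustration with made-up values:

import torch

beam_width = 2
vocab_size = 5
# Cumulative log-probabilities of 2 parent beams over a 5-token vocab.
logprobs = torch.tensor([
    [-1.0, -2.0, -0.5, -3.0, -4.0],
    [-0.2, -5.0, -1.5, -2.5, -0.8],
])

_, flat_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
flat_ids = flat_ids.tolist()          # [5, 2, 9, 0] for the values above
parent_ids = [i // vocab_size for i in flat_ids]
token_ids = [i % vocab_size for i in flat_ids]
print(parent_ids, token_ids)          # [1, 0, 1, 0] [0, 2, 4, 0]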
 def _sample(
...
@@ -364,65 +450,80 @@ def _sample(
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
-    seq_outputs: SamplerOutput = []
-
-    # TODO(woosuk): Optimize.
-    idx = 0
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    category_num_tokens = {t: 0 for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_group_outputs: List[SequenceOutputs] = []
         seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            parent_seq_id = seq_ids[0]
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
-
-            # Sample the next tokens.
-            next_token_ids = _sample_from_prompt(prob, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
-
-            # Build the output.
-            for next_token_id in next_token_ids:
-                output_logprobs = next_logprobs.copy()
-                output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        else:
-            # Generate the next tokens for generation tokens.
-            num_parent_seqs = len(seq_ids)
-            prob = probs[idx:idx + num_parent_seqs]
-            logprob = logprobs[idx:idx + num_parent_seqs]
-            idx += num_parent_seqs
-
-            # Sample the next tokens.
-            seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
-                seq_ids, prob, logprob, seq_logprobs, sampling_params)
-
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
-
-            # Build the output.
-            for parent_seq_id, next_token_id in zip(parent_seq_ids,
-                                                    next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        seq_outputs.append(seq_group_outputs)
-
-    return seq_outputs
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        category_num_tokens[sampling_type] += num_seqs
+
+    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    category_start_idx = 0
+    for sampling_type in SamplingType:
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        num_tokens = category_num_tokens[sampling_type]
+        if num_tokens == 0:
+            continue
+        category_logprobs = logprobs[category_start_idx:category_start_idx +
+                                     num_tokens]
+        category_probs = probs[category_start_idx:category_start_idx +
+                               num_tokens]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, category_logprobs)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            category_probs)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 input_metadata.seq_data,
+                                                 category_logprobs)
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+        # Batched query for logprobs of selected token
+        batched_logprobs_query_seq_indices: List[int] = []
+        batched_logprobs_query_token_indices: List[int] = []
+        sample_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_parent_seqs = len(seq_ids)
+            batched_logprobs_query_seq_indices.extend(
+                [sample_idx + parent_id for parent_id in parent_ids])
+            batched_logprobs_query_token_indices.extend(next_token_ids)
+            sample_idx += num_parent_seqs
+        assert sample_idx == num_tokens
+        batched_logprobs_query_result = category_logprobs[[
+            batched_logprobs_query_seq_indices,
+            batched_logprobs_query_token_indices
+        ]].tolist()
+
+        # Build the sequence outputs.
+        sample_idx = 0
+        result_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_results = len(next_token_ids)
+            num_parent_seqs = len(seq_ids)
+            parent_logprobs = category_logprobs[sample_idx:sample_idx +
+                                                num_parent_seqs]
+            selected_token_logprobs = batched_logprobs_query_result[
+                result_idx:result_idx + num_results]
+            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
+                                                 selected_token_logprobs,
+                                                 seq_ids, parent_logprobs,
+                                                 sampling_params.logprobs)
+            seq_outputs_dict[seq_group_id] = seq_output
+            sample_idx += num_parent_seqs
+            result_idx += num_results
+        assert sample_idx == num_tokens
+        category_start_idx += num_tokens
+
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
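One more aside on the new _sample above: the chosen (sequence, token) pairs of a whole category are gathered with a single advanced-indexing lookup on the log-probability tensor instead of one .item() call per token. A short standalone illustration of that indexing with toy data (the double-bracket form in the diff, category_logprobs[[rows, cols]], behaves the same as the tuple form used here):

import torch

logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)

# One selected token per row, as produced by the per-category samplers.
seq_indices = [0, 1, 2, 3]
token_indices = [5, 2, 7, 0]

# Pairwise gather: selected[i] == logprobs[seq_indices[i], token_indices[i]].
selected = logprobs[seq_indices, token_indices].tolist()
print(selected)  # four floats, one per (row, token) pair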
vllm/sampling_params.py
"""Sampling parameters for text generation."""
"""Sampling parameters for text generation."""
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
class
SamplingType
(
IntEnum
):
GREEDY
=
0
RANDOM
=
1
BEAM
=
2
class
SamplingParams
:
class
SamplingParams
:
"""Sampling parameters for text generation.
"""Sampling parameters for text generation.
...
@@ -166,6 +174,14 @@ class SamplingParams:
...
@@ -166,6 +174,14 @@ class SamplingParams:
if
self
.
top_k
!=
-
1
:
if
self
.
top_k
!=
-
1
:
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
raise
ValueError
(
"top_k must be -1 when using greedy sampling."
)
@
cached_property
def
sampling_type
(
self
)
->
SamplingType
:
if
self
.
use_beam_search
:
return
SamplingType
.
BEAM
if
self
.
temperature
<
_SAMPLING_EPS
:
return
SamplingType
.
GREEDY
return
SamplingType
.
RANDOM
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
f
"best_of=
{
self
.
best_of
}
, "
f
"best_of=
{
self
.
best_of
}
, "
...
...
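The SamplingType enum and the cached sampling_type property added here are what the new sampler keys its buckets on. A quick usage sketch, assuming only the constructor arguments that appear in this commit's tests (other arguments left at their defaults):

from vllm.sampling_params import SamplingParams, SamplingType

# temperature 0 with default top_p/top_k maps to greedy sampling.
assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY
# any positive temperature maps to random sampling.
assert SamplingParams(temperature=0.8,
                      top_p=0.95).sampling_type == SamplingType.RANDOM
# use_beam_search takes priority over the temperature check.
assert SamplingParams(temperature=0.0, use_beam_search=True,
                      best_of=2).sampling_type == SamplingType.BEAM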