norm/vllm · Commit a7347d9a (unverified)
Authored Dec 17, 2023 by Antoni Baum, committed via GitHub on Dec 17, 2023
Parent: f8c688d7

Make sampler less blocking (#1889)

Showing 2 changed files with 310 additions and 198 deletions:
  vllm/model_executor/layers/sampler.py      +123 -198
  vllm/model_executor/sampling_metadata.py   +187 -0
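What the change does: instead of building Python lists and converting them to GPU tensors inline (which blocks the forward pass), the sampler now packs all sampling parameters into SamplingTensors on a dedicated CUDA copy stream and makes the default stream wait only where those tensors are consumed. A minimal sketch of that overlap pattern, using only the standard torch.cuda stream API (the helper names below are illustrative, not vLLM code; requires a CUDA device):

    import torch

    # Sketch of the copy-stream overlap pattern introduced in this commit.
    copy_stream = torch.cuda.Stream()

    def stage_on_gpu(values, device, dtype):
        # Build the tensor in pinned CPU memory and issue the H2D copy on the
        # side stream so it can run concurrently with the default stream.
        with torch.cuda.stream(copy_stream):
            cpu_t = torch.tensor(values, device="cpu", dtype=dtype,
                                 pin_memory=True)
            return cpu_t.to(device=device, non_blocking=True)

    def scale_logits(logits, temperatures):
        # Kick off the async copy, then make the default stream wait only
        # right before the copied tensor is actually consumed.
        t = stage_on_gpu(temperatures, logits.device, logits.dtype)
        torch.cuda.current_stream().wait_stream(copy_stream)
        return logits.div_(t.unsqueeze(dim=1))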
vllm/model_executor/layers/sampler.py
@@ -6,13 +6,11 @@ import torch.nn as nn
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
 
-_SAMPLING_EPS = 1e-5
-
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -32,6 +30,7 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -47,40 +46,38 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
         _, vocab_size = logits.shape
 
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
+
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            (sampling_tensors, do_penalties, do_top_p_top_k,
+             do_min_p) = SamplingTensors.from_sampling_metadata(
+                 sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        torch.cuda.current_stream().wait_stream(self._copy_stream)
+
         # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(sampling_metadata))
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, sampling_metadata,
-                                  presence_penalties, frequency_penalties,
-                                  repetition_penalties)
+        if do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.presence_penalties,
+                                      sampling_tensors.frequency_penalties,
+                                      sampling_tensors.repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures,
-                             dtype=logits.dtype,
-                             device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
 
         # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        if do_top_p_top_k:
+            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
 
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
         if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
 
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
@@ -120,32 +117,6 @@ def _prune_hidden_states(
         sampling_metadata.selected_token_indices)
 
 
-def _get_penalties(
-        sampling_metadata: SamplingMetadata
-) -> Tuple[List[float], List[float], List[float]]:
-    # Collect the presence and frequency penalties.
-    presence_penalties: List[float] = []
-    frequency_penalties: List[float] = []
-    repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        r = sampling_params.repetition_penalty
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: We do not apply presence and frequency penalties for the
-            # prompt token positions where we don't sample new tokens.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            presence_penalties += [0] * (prompt_len - 1)
-            frequency_penalties += [0] * (prompt_len - 1)
-            repetition_penalties += [1] * (prompt_len - 1)
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-        repetition_penalties += [r] * len(seq_ids)
-    return presence_penalties, frequency_penalties, repetition_penalties
-
-
 def _get_prompt_and_output_tokens(
     sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
@@ -168,25 +139,16 @@ def _get_prompt_and_output_tokens(
 def _get_bin_counts_and_mask(
-    logits: torch.Tensor,
-    tokens: List[List[int]],
+    tokens: torch.Tensor,
     vocab_size: int,
     num_seqs: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    max_len = max(len(tokens) for tokens in tokens)
-    padded_tokens = [
-        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
-    ]
-    tokens_tensor = torch.tensor(padded_tokens,
-                                 dtype=torch.long,
-                                 device=logits.device)
-
     # Compute the bin counts for the tokens.
     # vocab_size + 1 for padding.
     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                              dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
     bin_counts = bin_counts[:, :vocab_size]
     mask = bin_counts > 0
@@ -217,45 +179,16 @@ def _apply_logits_processors(
     return logits
 
 
-def _apply_penalties(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    presence_penalties: List[float],
-    frequency_penalties: List[float],
-    repetition_penalties: List[float],
-) -> torch.Tensor:
+def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                     output_tokens_tensor: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
-    for i in range(num_seqs):
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        r = repetition_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
-                r - 1.0) < _SAMPLING_EPS:
-            continue
-        break
-    else:
-        # Return early if all sequences have zero penalties.
-        return logits
-
-    prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(sampling_metadata))
-    assert len(prompt_tokens) == logits.shape[0]
-    assert len(output_tokens) == logits.shape[0]
-
-    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
-        logits, prompt_tokens, vocab_size, num_seqs)
+    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
+                                              num_seqs)
     output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        logits, output_tokens, vocab_size, num_seqs)
-
-    repetition_penalties = torch.tensor(repetition_penalties,
-                                        dtype=logits.dtype,
-                                        device=logits.device)
-    frequency_penalties = torch.tensor(frequency_penalties,
-                                       dtype=logits.dtype,
-                                       device=logits.device)
-    presence_penalties = torch.tensor(presence_penalties,
-                                      dtype=logits.dtype,
-                                      device=logits.device)
+        output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
     repetition_penalties[~(prompt_mask | output_mask)] = 1.0
@@ -264,109 +197,65 @@ def _apply_penalties(
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
     return logits
 
 
-def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
-    # Collect the temperatures for the logits.
-    temperatures: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        temperature = sampling_params.temperature
-        if temperature < _SAMPLING_EPS:
-            # NOTE: Zero temperature means deterministic sampling
-            # (i.e., greedy sampling or beam search).
-            # Set the temperature to 1 to avoid division by zero.
-            temperature = 1.0
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            temperatures += [temperature] * (prompt_len - 1)
-        temperatures += [temperature] * len(seq_ids)
-    return temperatures
-
-
-def _get_top_p_top_k_min_p(
-    sampling_metadata: SamplingMetadata,
-    vocab_size: int,
-) -> Tuple[List[float], List[int], List[float]]:
-    top_ps: List[float] = []
-    top_ks: List[int] = []
-    min_ps: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
-        min_p = sampling_params.min_p
-        # k should not be greater than the vocab size.
-        top_k = min(sampling_params.top_k, vocab_size)
-        # k=-1 means no truncation.
-        top_k = vocab_size if top_k == -1 else top_k
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            top_ps += [top_p] * (prompt_len - 1)
-            top_ks += [top_k] * (prompt_len - 1)
-            min_ps += [min_p] * (prompt_len - 1)
-        top_ps += [top_p] * len(seq_ids)
-        top_ks += [top_k] * len(seq_ids)
-        min_ps += [min_p] * len(seq_ids)
-    return top_ps, top_ks, min_ps
-
-
 def _apply_top_p_top_k(
     logits: torch.Tensor,
-    top_ps: List[float],
-    top_ks: List[int],
+    p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
+    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
+    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
 
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    logits_sort[top_k_mask] = -float("inf")
+    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
 
+    # Final mask.
+    mask = (top_p_mask | top_k_mask)
+    logits_sort.masked_fill_(mask, -float("inf"))
     # Re-sort the probabilities.
-    logits = torch.gather(logits_sort,
-                          dim=-1,
-                          index=torch.argsort(logits_idx, dim=-1))
+    src = torch.arange(logits_idx.shape[-1],
+                       device=logits_idx.device).expand_as(logits_idx)
+    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
+                                                           index=logits_idx,
+                                                           src=src)
+    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
     return logits
 
 
 def _apply_min_p(
     logits: torch.Tensor,
-    min_ps: List[float],
+    min_p: torch.Tensor,
 ) -> torch.Tensor:
     """
     Adapted from
     https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
     """
-    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
     probs = torch.softmax(logits, dim=-1)
     top_probs, _ = probs.max(dim=-1, keepdim=True)
-    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
     tokens_to_remove = probs < scaled_min_p
-    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
 
     return logits
 
 
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
-    logprobs: torch.Tensor,
+    samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = torch.argmax(logprobs, dim=-1).cpu()
+    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
@@ -375,27 +264,19 @@ def _greedy_sample(
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx].item()]
+        next_token_ids = [samples[sample_idx]]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
     return results
 
 
 def _random_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     is_prompts: List[bool],
-    probs: torch.Tensor,
+    random_samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    # Find the maximum best_of value of the prompt phase requests.
-    max_best_of = 1
-    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
-        if is_prompt:
-            seq_ids, sampling_params = seq_group
-            max_best_of = max(max_best_of, sampling_params.best_of)
-    random_samples = torch.multinomial(probs,
-                                       num_samples=max_best_of,
-                                       replacement=True).cpu()
+    random_samples = random_samples.cpu()
     sample_idx = 0
     results = []
     for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
@@ -403,8 +284,6 @@ def _random_sample(
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
             parent_ids = [0] * sampling_params.best_of
             next_token_ids = random_samples[
                 sample_idx, :sampling_params.best_of].tolist()
@@ -415,7 +294,6 @@ def _random_sample(
                 num_parent_seqs, 0].tolist()
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == probs.size(0)
     return results
@@ -472,6 +350,28 @@ def _beam_search_sample(
     return results
 
 
+# torch.multinomial forces a GPU<->CPU sync.
+# Therefore, we use an optimized implementation instead.
+# Note that we always sample with replacement.
+# probs will be modified in place, but this is fine, as we pass
+# in a copy already.
+def _multinomial(
+    probs: torch.Tensor,
+    num_samples: int,
+):
+    if num_samples > 1:
+        # This is equivalent to torch.repeat_interleaved (which also
+        # forces a GPU<->CPU sync).
+        # This allows us to do sampling with replacement by creating
+        # num_samples copies of each row in the tensor, and then
+        # batch sampling the resulting tensor.
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
+    q = torch.empty_like(probs).exponential_(1)
+    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
+
+
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -485,28 +385,51 @@ def _sample(
         categorized_seq_group_ids[sampling_type].append(i)
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
 
+    # Counterintiutively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _greedy_sample(seq_groups, category_logprobs)
+            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
         elif sampling_type == SamplingType.RANDOM:
-            category_probs = probs[sample_indices]
-            sample_results = _random_sample(seq_groups, is_prompts,
-                                            category_probs)
+            max_best_of = 1
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of = max(max_best_of, sampling_params.best_of)
+            multinomial_samples = _multinomial(probs[sample_indices],
+                                               max_best_of)
        elif sampling_type == SamplingType.BEAM:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 sampling_metadata.seq_data,
-                                                 category_logprobs)
+            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
-        sample_results_dict.update(zip(seq_group_ids, sample_results))
+
+    # GPU<->CPU sync happens in the loop below.
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
+            sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            multinomial_samples)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 sampling_metadata.seq_data,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_ids, sample_results))
 
     sample_results = [
@@ -557,7 +480,7 @@ def _get_logprobs(
     batched_logprobs_query_result = logprobs[[
         batched_logprobs_query_seq_indices,
         batched_logprobs_query_token_indices
-    ]].cpu()
+    ]]
 
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
@@ -569,6 +492,8 @@ def _get_logprobs(
     else:
         top_logprobs, top_token_ids = None, None
 
+    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
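The _multinomial helper added above replaces torch.multinomial, which forces a GPU<->CPU sync, with the exponential-race trick: dividing each probability by an independent Exponential(1) draw and taking the argmax selects index i with probability probs[i] / probs.sum(). The two-pass _sample loop has the same motive: queue all sync-free GPU work first, and pay the .cpu()/.tolist() synchronization only in the second pass. A small self-contained sketch of the sampling trick (helper name and test values are illustrative; CPU-only):

    import torch

    def multinomial_no_sync(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
        # Same idea as _multinomial in this commit: dividing each probability
        # by an independent Exponential(1) draw and taking the argmax samples
        # index i with probability probs[i] / probs.sum().
        if num_samples > 1:
            probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                             probs.shape[1]).contiguous().view(
                                                 -1, probs.shape[1])
        q = torch.empty_like(probs).exponential_(1)
        return probs.div(q).argmax(dim=1).view(-1, num_samples)

    # Rough sanity check: empirical frequencies should approach the target
    # distribution (here roughly [0.1, 0.2, 0.7]).
    torch.manual_seed(0)
    p = torch.tensor([[0.1, 0.2, 0.7]])
    draws = torch.cat([multinomial_no_sync(p, 1) for _ in range(10000)])
    print(torch.bincount(draws.view(-1), minlength=3) / 10000.0)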
vllm/model_executor/sampling_metadata.py
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
 
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
+from vllm.utils import in_wsl
+
+_SAMPLING_EPS = 1e-5
 
 
 class SamplingMetadata:
@@ -41,3 +45,186 @@ class SamplingMetadata:
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
             f"categorized_sample_indices={self.categorized_sample_indices})")
+
+
+@dataclass
+class SamplingTensors:
+    """Tensors for sampling."""
+
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+    presence_penalties: torch.Tensor
+    frequency_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+    prompt_tokens: torch.Tensor
+    output_tokens: torch.Tensor
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
+            device: torch.device,
+            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: List[List[int]] = []
+        output_tokens: List[List[int]] = []
+        top_ks: List[int] = []
+        temperatures: List[float] = []
+        top_ps: List[float] = []
+        min_ps: List[float] = []
+        presence_penalties: List[float] = []
+        frequency_penalties: List[float] = []
+        repetition_penalties: List[float] = []
+        do_penalties = False
+        do_top_p_top_k = False
+        do_min_p = False
+        for i, seq_group in enumerate(sampling_metadata.seq_groups):
+            seq_ids, sampling_params = seq_group
+            temperature = sampling_params.temperature
+            p = sampling_params.presence_penalty
+            f = sampling_params.frequency_penalty
+            r = sampling_params.repetition_penalty
+            top_p = sampling_params.top_p
+            min_p = sampling_params.min_p
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
+            if not do_penalties and (abs(p) >= _SAMPLING_EPS
+                                     or abs(f) >= _SAMPLING_EPS
+                                     or abs(r - 1.0) >= _SAMPLING_EPS):
+                do_penalties = True
+            if (i < sampling_metadata.num_prompts
+                    and sampling_params.prompt_logprobs is not None):
+                # For tokens in the prompt that we only need to get their logprobs
+                prompt_len = sampling_metadata.prompt_lens[i]
+                temperatures += [temperature] * (prompt_len - 1)
+                top_ps += [top_p] * (prompt_len - 1)
+                top_ks += [top_k] * (prompt_len - 1)
+                min_ps += [min_p] * (prompt_len - 1)
+                presence_penalties += [0] * (prompt_len - 1)
+                frequency_penalties += [0] * (prompt_len - 1)
+                repetition_penalties += [1] * (prompt_len - 1)
+                prompt_tokens.extend([] for _ in range(prompt_len - 1))
+                output_tokens.extend([] for _ in range(prompt_len - 1))
+            for seq_id in seq_ids:
+                seq_data = sampling_metadata.seq_data[seq_id]
+                prompt_tokens.append(seq_data.prompt_token_ids)
+                output_tokens.append(seq_data.output_token_ids)
+            temperatures += [temperature] * len(seq_ids)
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+            min_ps += [min_p] * len(seq_ids)
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+            repetition_penalties += [r] * len(seq_ids)
+
+        sampling_tensors = SamplingTensors.from_lists(
+            temperatures, top_ps, top_ks, min_ps, presence_penalties,
+            frequency_penalties, repetition_penalties, prompt_tokens,
+            output_tokens, vocab_size, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+
+    @classmethod
+    def from_lists(cls, temperatures: List[float], top_ps: List[float],
+                   top_ks: List[int], min_ps: List[float],
+                   presence_penalties: List[float],
+                   frequency_penalties: List[float],
+                   repetition_penalties: List[float],
+                   prompt_tokens: List[List[int]],
+                   output_tokens: List[List[int]], vocab_size: int,
+                   device: torch.device,
+                   dtype: torch.dtype) -> "SamplingTensors":
+        # Note that the performance will be very bad without
+        # pinned memory.
+        pin_memory = not in_wsl()
+        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+        prompt_padded_tokens = [
+            tokens + [vocab_size] * (prompt_max_len - len(tokens))
+            for tokens in prompt_tokens
+        ]
+        output_max_len = max(len(tokens) for tokens in output_tokens)
+        output_padded_tokens = [
+            tokens + [vocab_size] * (output_max_len - len(tokens))
+            for tokens in output_tokens
+        ]
+
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        # Because the memory is pinned, we can do non-blocking
+        # transfer to device.
+        return cls(
+            temperatures=temperatures_t.to(device=device, non_blocking=True),
+            top_ps=top_ps_t.to(device=device, non_blocking=True),
+            top_ks=top_ks_t.to(device=device, non_blocking=True),
+            min_ps=min_ps_t.to(device=device, non_blocking=True),
+            presence_penalties=presence_penalties_t.to(device=device,
+                                                       non_blocking=True),
+            frequency_penalties=frequency_penalties_t.to(device=device,
+                                                         non_blocking=True),
+            repetition_penalties=repetition_penalties_t.to(device=device,
+                                                           non_blocking=True),
+            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
+            output_tokens=output_tensor.to(device=device, non_blocking=True),
+        )
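from_lists pads the ragged prompt/output token lists into rectangular tensors, using vocab_size as an out-of-vocabulary filler, and _get_bin_counts_and_mask in sampler.py later counts per-sequence token occurrences with one scatter_add_ and slices the filler bin off. A standalone sketch of that padding-plus-bincount step (the helper name and sample data below are made up for illustration):

    import torch

    def bin_counts_and_mask(token_lists, vocab_size):
        # Mirrors the padding in SamplingTensors.from_lists plus the
        # scatter_add_ in _get_bin_counts_and_mask: pad ragged token lists
        # with `vocab_size` (an extra out-of-vocabulary bin), count
        # occurrences per row, then drop the padding bin.
        max_len = max(len(t) for t in token_lists)
        padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
        tokens = torch.tensor(padded, dtype=torch.long)
        counts = torch.zeros((len(token_lists), vocab_size + 1),
                             dtype=torch.long)
        counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        counts = counts[:, :vocab_size]
        return counts, counts > 0

    counts, mask = bin_counts_and_mask([[1, 1, 2], [0]], vocab_size=4)
    print(counts)  # tensor([[0, 2, 1, 0], [1, 0, 0, 0]])
    print(mask)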