vllm · Commit a7347d9a (unverified)
Authored Dec 17, 2023 by Antoni Baum, committed by GitHub on Dec 17, 2023
Parent: f8c688d7

    Make sampler less blocking (#1889)

Changes: 2 changed files, with 310 additions and 198 deletions

    vllm/model_executor/layers/sampler.py      +123  -198
    vllm/model_executor/sampling_metadata.py   +187  -0
vllm/model_executor/layers/sampler.py (view file @ a7347d9a)

@@ -6,13 +6,11 @@ import torch.nn as nn
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)

-_SAMPLING_EPS = 1e-5
-

 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.

@@ -32,6 +30,7 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()

     def forward(
         self,

@@ -47,40 +46,38 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
+        _, vocab_size = logits.shape

         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
+
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            (sampling_tensors, do_penalties, do_top_p_top_k,
+             do_min_p) = SamplingTensors.from_sampling_metadata(
+                 sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        torch.cuda.current_stream().wait_stream(self._copy_stream)
+
         # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(sampling_metadata))
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, sampling_metadata,
-                                  presence_penalties, frequency_penalties,
-                                  repetition_penalties)
+        if do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.presence_penalties,
+                                      sampling_tensors.frequency_penalties,
+                                      sampling_tensors.repetition_penalties)

         # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures,
-                             dtype=logits.dtype,
-                             device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        if do_top_p_top_k:
+            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)

-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
         if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)

         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
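The hunk above carries the main idea of the commit: `SamplingTensors.from_sampling_metadata` now builds its tensors on a dedicated CUDA stream so the host-to-GPU copies can overlap with GPU work already queued on the default stream, and `wait_stream` re-synchronizes before the tensors are consumed. A minimal sketch of that pattern in isolation (hypothetical tensors and parameters, not vLLM code; assumes a CUDA device is available):

import torch

# Minimal sketch of the overlap pattern used above (hypothetical data,
# not vLLM code; assumes a CUDA device is available).
copy_stream = torch.cuda.Stream()


def forward_step(gpu_input: torch.Tensor, host_params: list) -> torch.Tensor:
    # Heavy GPU work queued on the default stream.
    logits = gpu_input @ gpu_input.T

    # Build the small parameter tensor on a side stream so its host-to-device
    # copy overlaps with the matmul above instead of serializing after it.
    with torch.cuda.stream(copy_stream):
        params = torch.tensor(host_params,
                              device="cpu",
                              dtype=logits.dtype,
                              pin_memory=True).to(logits.device,
                                                  non_blocking=True)

    # The default stream must wait for the copy before any kernel reads it.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return logits / params.unsqueeze(1)

Under these assumptions the copy is issued while the matmul is still running, which is exactly what `Sampler.forward` does with `self._copy_stream`.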
@@ -120,32 +117,6 @@ def _prune_hidden_states(
                                      sampling_metadata.selected_token_indices)


-def _get_penalties(
-        sampling_metadata: SamplingMetadata
-) -> Tuple[List[float], List[float], List[float]]:
-    # Collect the presence and frequency penalties.
-    presence_penalties: List[float] = []
-    frequency_penalties: List[float] = []
-    repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        r = sampling_params.repetition_penalty
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: We do not apply presence and frequency penalties for the
-            # prompt token positions where we don't sample new tokens.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            presence_penalties += [0] * (prompt_len - 1)
-            frequency_penalties += [0] * (prompt_len - 1)
-            repetition_penalties += [1] * (prompt_len - 1)
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-        repetition_penalties += [r] * len(seq_ids)
-    return presence_penalties, frequency_penalties, repetition_penalties
-
-
 def _get_prompt_and_output_tokens(
     sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:

@@ -168,25 +139,16 @@ def _get_prompt_and_output_tokens(
 def _get_bin_counts_and_mask(
-    logits: torch.Tensor,
-    tokens: List[List[int]],
+    tokens: torch.Tensor,
     vocab_size: int,
     num_seqs: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    max_len = max(len(tokens) for tokens in tokens)
-    padded_tokens = [
-        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
-    ]
-    tokens_tensor = torch.tensor(padded_tokens,
-                                 dtype=torch.long,
-                                 device=logits.device)
-
     # Compute the bin counts for the tokens.
     # vocab_size + 1 for padding.
     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                              dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
     bin_counts = bin_counts[:, :vocab_size]
     mask = bin_counts > 0
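For reference, the `scatter_add_` bin-count trick that `_get_bin_counts_and_mask` relies on (now fed a padded token tensor built ahead of time in `SamplingTensors`) can be exercised on its own. The small example below uses made-up token ids and a toy vocab size and is not part of the commit:

import torch

# Illustrative example (not part of the commit): count token occurrences per
# sequence into vocab_size + 1 bins, where the extra bin absorbs the padding
# id (== vocab_size), then drop it.
vocab_size = 5
tokens = torch.tensor([[0, 2, 2, 5],    # 5 == vocab_size is padding
                       [1, 1, 4, 5]])
num_seqs = tokens.shape[0]

bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]   # drop the padding bin
mask = bin_counts > 0

print(bin_counts)  # tensor([[1, 0, 2, 0, 0], [0, 2, 0, 0, 1]])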
@@ -217,45 +179,16 @@ def _apply_logits_processors(
     return logits


-def _apply_penalties(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    presence_penalties: List[float],
-    frequency_penalties: List[float],
-    repetition_penalties: List[float],
-) -> torch.Tensor:
+def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                     output_tokens_tensor: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
-    for i in range(num_seqs):
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        r = repetition_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
-                r - 1.0) < _SAMPLING_EPS:
-            continue
-        break
-    else:
-        # Return early if all sequences have zero penalties.
-        return logits
-
-    prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(sampling_metadata))
-    assert len(prompt_tokens) == logits.shape[0]
-    assert len(output_tokens) == logits.shape[0]
-
-    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
-        logits, prompt_tokens, vocab_size, num_seqs)
+    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
+                                              num_seqs)
     output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        logits, output_tokens, vocab_size, num_seqs)
-
-    repetition_penalties = torch.tensor(repetition_penalties,
-                                        dtype=logits.dtype,
-                                        device=logits.device)
-    frequency_penalties = torch.tensor(frequency_penalties,
-                                       dtype=logits.dtype,
-                                       device=logits.device)
-    presence_penalties = torch.tensor(presence_penalties,
-                                      dtype=logits.dtype,
-                                      device=logits.device)
+        output_tokens_tensor, vocab_size, num_seqs)

     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
     repetition_penalties[~(prompt_mask | output_mask)] = 1.0
@@ -264,109 +197,65 @@ def _apply_penalties(
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
     return logits


-def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
-    # Collect the temperatures for the logits.
-    temperatures: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        temperature = sampling_params.temperature
-        if temperature < _SAMPLING_EPS:
-            # NOTE: Zero temperature means deterministic sampling
-            # (i.e., greedy sampling or beam search).
-            # Set the temperature to 1 to avoid division by zero.
-            temperature = 1.0
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            temperatures += [temperature] * (prompt_len - 1)
-        temperatures += [temperature] * len(seq_ids)
-    return temperatures
-
-
-def _get_top_p_top_k_min_p(
-    sampling_metadata: SamplingMetadata,
-    vocab_size: int,
-) -> Tuple[List[float], List[int], List[float]]:
-    top_ps: List[float] = []
-    top_ks: List[int] = []
-    min_ps: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
-        min_p = sampling_params.min_p
-        # k should not be greater than the vocab size.
-        top_k = min(sampling_params.top_k, vocab_size)
-        # k=-1 means no truncation.
-        top_k = vocab_size if top_k == -1 else top_k
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            top_ps += [top_p] * (prompt_len - 1)
-            top_ks += [top_k] * (prompt_len - 1)
-            min_ps += [min_p] * (prompt_len - 1)
-        top_ps += [top_p] * len(seq_ids)
-        top_ks += [top_k] * len(seq_ids)
-        min_ps += [min_p] * len(seq_ids)
-    return top_ps, top_ks, min_ps
-
-
 def _apply_top_p_top_k(
     logits: torch.Tensor,
-    top_ps: List[float],
-    top_ks: List[int],
+    p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
+    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
+    top_p_mask = probs_sum > p.unsqueeze_(dim=1)

     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    logits_sort[top_k_mask] = -float("inf")
+    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
+
+    # Final mask.
+    mask = (top_p_mask | top_k_mask)
+    logits_sort.masked_fill_(mask, -float("inf"))

     # Re-sort the probabilities.
-    logits = torch.gather(logits_sort,
-                          dim=-1,
-                          index=torch.argsort(logits_idx, dim=-1))
+    src = torch.arange(logits_idx.shape[-1],
+                       device=logits_idx.device).expand_as(logits_idx)
+    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
+                                                           index=logits_idx,
+                                                           src=src)
+    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
     return logits


 def _apply_min_p(
     logits: torch.Tensor,
-    min_ps: List[float],
+    min_p: torch.Tensor,
 ) -> torch.Tensor:
     """
     Adapted from
     https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
     """
-    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
     probs = torch.softmax(logits, dim=-1)
     top_probs, _ = probs.max(dim=-1, keepdim=True)
-    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
     tokens_to_remove = probs < scaled_min_p
-    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
     return logits


 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
-    logprobs: torch.Tensor,
+    samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = torch.argmax(logprobs, dim=-1).cpu()
+    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
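The re-sort at the end of the new `_apply_top_p_top_k` avoids a second `argsort` by building the inverse permutation with `scatter_`. A small standalone check of that identity (illustrative values, not part of the commit):

import torch

# Illustrative check: the inverse of the permutation returned by sort() can be
# built with scatter_, and gathering with it restores the original order.
x = torch.tensor([[3.0, 1.0, 2.0]])
x_sort, idx = x.sort(dim=-1, descending=True)

src = torch.arange(idx.shape[-1], device=idx.device).expand_as(idx)
idx_inv = torch.empty_like(idx).scatter_(dim=-1, index=idx, src=src)

assert torch.equal(torch.gather(x_sort, dim=-1, index=idx_inv), x)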
@@ -375,27 +264,19 @@ def _greedy_sample(
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx].item()]
+        next_token_ids = [samples[sample_idx]]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
     return results


 def _random_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     is_prompts: List[bool],
-    probs: torch.Tensor,
+    random_samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
     # Find the maximum best_of value of the prompt phase requests.
-    max_best_of = 1
-    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
-        if is_prompt:
-            seq_ids, sampling_params = seq_group
-            max_best_of = max(max_best_of, sampling_params.best_of)
-    random_samples = torch.multinomial(probs,
-                                       num_samples=max_best_of,
-                                       replacement=True).cpu()
+    random_samples = random_samples.cpu()
     sample_idx = 0
     results = []
     for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):

@@ -403,8 +284,6 @@ def _random_sample(
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
             parent_ids = [0] * sampling_params.best_of
             next_token_ids = random_samples[
                 sample_idx, :sampling_params.best_of].tolist()

@@ -415,7 +294,6 @@ def _random_sample(
                                           num_parent_seqs, 0].tolist()
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == probs.size(0)
     return results
@@ -472,6 +350,28 @@ def _beam_search_sample(
     return results


+# torch.multinomial forces a GPU<->CPU sync.
+# Therefore, we use an optimized implementation instead.
+# Note that we always sample with replacement.
+# probs will be modified in place, but this is fine, as we pass
+# in a copy already.
+def _multinomial(
+    probs: torch.Tensor,
+    num_samples: int,
+):
+    if num_samples > 1:
+        # This is equivalent to torch.repeat_interleaved (which also
+        # forces a GPU<->CPU sync).
+        # This allows us to do sampling with replacement by creating
+        # num_samples copies of each row in the tensor, and then
+        # batch sampling the resulting tensor.
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
+    q = torch.empty_like(probs).exponential_(1)
+    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
+
+
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
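The `_multinomial` helper added above replaces `torch.multinomial`, which forces a device-to-host sync, with the exponential-race form of the Gumbel-max trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking an argmax draws from the same categorical distribution. A quick standalone check of that equivalence (illustrative only, not part of the commit):

import torch

# Illustrative check that argmax(probs / Exponential(1)) samples from the
# categorical distribution given by `probs`.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

counts = torch.zeros(4)
for _ in range(20000):
    q = torch.empty_like(probs).exponential_(1)
    counts[probs.div(q).argmax()] += 1

print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.3, 0.4])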
@@ -485,28 +385,51 @@ def _sample(
         categorized_seq_group_ids[sampling_type].append(i)

     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
+
+    # Counterintiutively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices)
+        if sampling_type == SamplingType.GREEDY:
+            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+        elif sampling_type == SamplingType.RANDOM:
+            max_best_of = 1
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of = max(max_best_of, sampling_params.best_of)
+            multinomial_samples = _multinomial(probs[sample_indices],
+                                               max_best_of)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+    # GPU<->CPU sync happens in the loop below.
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
+            sampling_type]
         if sampling_type == SamplingType.GREEDY:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _greedy_sample(seq_groups, category_logprobs)
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type == SamplingType.RANDOM:
-            category_probs = probs[sample_indices]
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            category_probs)
+                                            multinomial_samples)
         elif sampling_type == SamplingType.BEAM:
-            category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
-                                                 category_logprobs)
+                                                 beam_search_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
         sample_results_dict.update(zip(seq_group_ids, sample_results))

     sample_results = [
@@ -557,7 +480,7 @@ def _get_logprobs(
     batched_logprobs_query_result = logprobs[[
         batched_logprobs_query_seq_indices,
         batched_logprobs_query_token_indices
-    ]].cpu()
+    ]]

     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:

@@ -569,6 +492,8 @@ def _get_logprobs(
     else:
         top_logprobs, top_token_ids = None, None

+    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
vllm/model_executor/sampling_metadata.py (view file @ a7347d9a)

+from dataclasses import dataclass
 from typing import Dict, List, Tuple

 import torch

 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
+from vllm.utils import in_wsl
+
+_SAMPLING_EPS = 1e-5


 class SamplingMetadata:
@@ -41,3 +45,186 @@ class SamplingMetadata:
                 f"prompt_lens={self.prompt_lens}, "
                 f"selected_token_indices={self.selected_token_indices}, "
                 f"categorized_sample_indices={self.categorized_sample_indices})")
+
+
+@dataclass
+class SamplingTensors:
+    """Tensors for sampling."""
+
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+    presence_penalties: torch.Tensor
+    frequency_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+    prompt_tokens: torch.Tensor
+    output_tokens: torch.Tensor
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
+            device: torch.device,
+            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: List[List[int]] = []
+        output_tokens: List[List[int]] = []
+        top_ks: List[int] = []
+        temperatures: List[float] = []
+        top_ps: List[float] = []
+        min_ps: List[float] = []
+        presence_penalties: List[float] = []
+        frequency_penalties: List[float] = []
+        repetition_penalties: List[float] = []
+        do_penalties = False
+        do_top_p_top_k = False
+        do_min_p = False
+        for i, seq_group in enumerate(sampling_metadata.seq_groups):
+            seq_ids, sampling_params = seq_group
+            temperature = sampling_params.temperature
+            p = sampling_params.presence_penalty
+            f = sampling_params.frequency_penalty
+            r = sampling_params.repetition_penalty
+            top_p = sampling_params.top_p
+            min_p = sampling_params.min_p
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
+            if not do_penalties and (abs(p) >= _SAMPLING_EPS
+                                     or abs(f) >= _SAMPLING_EPS
+                                     or abs(r - 1.0) >= _SAMPLING_EPS):
+                do_penalties = True
+            if (i < sampling_metadata.num_prompts
+                    and sampling_params.prompt_logprobs is not None):
+                # For tokens in the prompt that we only need to get their logprobs
+                prompt_len = sampling_metadata.prompt_lens[i]
+                temperatures += [temperature] * (prompt_len - 1)
+                top_ps += [top_p] * (prompt_len - 1)
+                top_ks += [top_k] * (prompt_len - 1)
+                min_ps += [min_p] * (prompt_len - 1)
+                presence_penalties += [0] * (prompt_len - 1)
+                frequency_penalties += [0] * (prompt_len - 1)
+                repetition_penalties += [1] * (prompt_len - 1)
+                prompt_tokens.extend([] for _ in range(prompt_len - 1))
+                output_tokens.extend([] for _ in range(prompt_len - 1))
+            for seq_id in seq_ids:
+                seq_data = sampling_metadata.seq_data[seq_id]
+                prompt_tokens.append(seq_data.prompt_token_ids)
+                output_tokens.append(seq_data.output_token_ids)
+            temperatures += [temperature] * len(seq_ids)
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+            min_ps += [min_p] * len(seq_ids)
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+            repetition_penalties += [r] * len(seq_ids)
+
+        sampling_tensors = SamplingTensors.from_lists(
+            temperatures, top_ps, top_ks, min_ps, presence_penalties,
+            frequency_penalties, repetition_penalties, prompt_tokens,
+            output_tokens, vocab_size, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+
+    @classmethod
+    def from_lists(cls, temperatures: List[float], top_ps: List[float],
+                   top_ks: List[int], min_ps: List[float],
+                   presence_penalties: List[float],
+                   frequency_penalties: List[float],
+                   repetition_penalties: List[float],
+                   prompt_tokens: List[List[int]],
+                   output_tokens: List[List[int]], vocab_size: int,
+                   device: torch.device,
+                   dtype: torch.dtype) -> "SamplingTensors":
+        # Note that the performance will be very bad without
+        # pinned memory.
+        pin_memory = not in_wsl()
+        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+        prompt_padded_tokens = [
+            tokens + [vocab_size] * (prompt_max_len - len(tokens))
+            for tokens in prompt_tokens
+        ]
+        output_max_len = max(len(tokens) for tokens in output_tokens)
+        output_padded_tokens = [
+            tokens + [vocab_size] * (output_max_len - len(tokens))
+            for tokens in output_tokens
+        ]
+
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        # Because the memory is pinned, we can do non-blocking
+        # transfer to device.
+        return cls(
+            temperatures=temperatures_t.to(device=device, non_blocking=True),
+            top_ps=top_ps_t.to(device=device, non_blocking=True),
+            top_ks=top_ks_t.to(device=device, non_blocking=True),
+            min_ps=min_ps_t.to(device=device, non_blocking=True),
+            presence_penalties=presence_penalties_t.to(device=device,
+                                                       non_blocking=True),
+            frequency_penalties=frequency_penalties_t.to(device=device,
+                                                         non_blocking=True),
+            repetition_penalties=repetition_penalties_t.to(device=device,
+                                                           non_blocking=True),
+            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
+            output_tokens=output_tensor.to(device=device, non_blocking=True),
+        )
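The `from_lists` factory above relies on two standard PyTorch behaviours: a host-to-device copy can only run asynchronously when the source CPU tensor sits in pinned (page-locked) memory, and `non_blocking=True` then lets `.to()` return without waiting for the copy to finish (hence the WSL check, where pinning performs poorly). A minimal sketch of the same idea outside vLLM (hypothetical values; assumes a CUDA device):

import torch

# Minimal sketch of the pinned-memory + non_blocking transfer used in
# SamplingTensors.from_lists (hypothetical values; assumes CUDA is available).
values = [0.7, 0.95, 1.0, 0.8]

# Staging tensor lives in page-locked host memory so the copy engine can
# DMA it to the GPU asynchronously.
staging = torch.tensor(values, device="cpu", dtype=torch.float32,
                       pin_memory=True)

# Returns immediately; the copy overlaps with whatever kernels are queued.
on_gpu = staging.to(device="cuda", non_blocking=True)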