Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
371d04d3
Unverified
Commit
371d04d3
authored
Dec 27, 2024
by
Woosuk Kwon
Committed by
GitHub
Dec 27, 2024
Browse files
[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
0c0c2015
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
355 additions
and
190 deletions
+355
-190
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+22
-32
vllm/envs.py
vllm/envs.py
+3
-2
vllm/v1/sample/ops/__init__.py
vllm/v1/sample/ops/__init__.py
+0
-0
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+57
-0
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+201
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+72
-156
No files found.
tests/v1/sample/test_sampler.py
View file @
371d04d3
...
...
@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
no_top_p
=
True
,
no_top_k
=
True
,
generators
=
{},
max_num_logprobs
=
VOCAB_SIZE
,
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
...
...
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
sampling_metadata
.
min_tokens
=
min_tokens
sampling_metadata
.
stop_token_ids
=
stop_token_ids
sampler
=
Sampler
()
sampler_output
=
sampler
(
fake_logits
,
sampling_metadata
)
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
vocab
in
range
(
VOCAB_SIZE
):
# Verify that the logprobs for stop token ids is set
# to -inf.
logprob_index
=
torch
.
where
(
sampler_output
.
logprob_token_ids
[
batch_idx
]
==
vocab
)[
0
].
item
()
if
vocab
in
stop_token_ids
[
batch_idx
]:
assert
sampler_output
.
logprobs
[
batch_idx
][
logprob_index
]
==
-
float
(
"inf"
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
in
stop_token_ids
[
batch_idx
]:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
assert
sampler_output
.
logprobs
[
batch_idx
][
logprob_index
]
!=
-
float
(
"inf"
)
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
batch_size
,
presence_penalty
,
torch
.
device
(
device
))
sampling_metadata
.
no_penalties
=
False
sampler
=
Sampler
()
sampler_output
=
sampler
(
fake_logits
,
sampling_metadata
)
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
# The logprobs in the SamplerOutput are arranged in descending order.
# Since all tokens initially have the same logprobs, the non-penalized
# tokens will appear at the beginning, while the penalized tokens
# will appear at the end of the list.
penalized_token_id
=
sampler_output
.
logprob_token_ids
[
batch_idx
][
VOCAB_SIZE
-
1
]
penalized_log_prod
=
sampler_output
.
logprobs
[
batch_idx
][
VOCAB_SIZE
-
1
]
non_penalized_token_id
=
sampler_output
.
logprob_token_ids
[
batch_idx
][
0
]
non_penalized_log_prod
=
sampler_output
.
logprobs
[
batch_idx
][
0
]
assert
non_penalized_log_prod
>
penalized_log_prod
# Since all tokens initially have the same logits, the non-penalized
# token ID will be the one with the highest logit value, while the
# penalized token ID will be the one with the lowest logit value.
non_penalized_token_id
=
logits
[
batch_idx
].
argmax
().
item
()
penalized_token_id
=
logits
[
batch_idx
].
argmin
().
item
()
if
presence_penalty
>
0
:
# If `presence_penalty` is set to a value greater than 0, it
# indicates a preference for new tokens over those already
...
...
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
sampling_metadata
.
output_token_ids
=
output_token_ids
sampling_metadata
.
no_penalties
=
False
sampler
=
Sampler
()
sampler_output
=
sampler
(
fake_logits
,
sampling_metadata
)
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logprobs_token_ids
=
sampler_output
.
logprob_token_ids
[
batch_idx
]
non_penalized_token_id
=
logprobs_token_ids
[
0
]
penalized_token_id
=
logprobs_token_ids
[
VOCAB_SIZE
-
1
]
non_penalized_token_id
=
logits
[
batch_idx
].
argmax
().
item
()
penalized_token_id
=
logits
[
batch_idx
].
argmin
().
item
()
distinct_sorted_token_ids_in_output
=
\
sorted_token_ids_in_output
[
batch_idx
]
most_frequent_token_id
=
distinct_sorted_token_ids_in_output
[
...
...
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
batch_size
,
repetition_penalty
,
torch
.
device
(
device
))
sampling_metadata
.
no_penalties
=
False
sampler
=
Sampler
()
sampler_output
=
sampler
(
fake_logits
,
sampling_metadata
)
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logprobs_token_ids
=
sampler_output
.
logprob_token_ids
[
batch_idx
]
non_penalized_token_id
=
logprobs_token_ids
[
0
]
penalized_token_id
=
logprobs_token_ids
[
VOCAB_SIZE
-
1
]
non_penalized_token_id
=
logits
[
batch_idx
].
argmax
().
item
()
penalized_token_id
=
logits
[
batch_idx
].
argmin
().
item
()
prompt_tokens
=
sampling_metadata
.
prompt_token_ids
[
batch_idx
][:].
tolist
()
output_tokens
=
sampling_metadata
.
output_token_ids
[
batch_idx
]
...
...
vllm/envs.py
View file @
371d04d3
...
...
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
Fals
e
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
Non
e
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
...
...
@@ -277,7 +277,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_SAMPLER"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_USE_FLASHINFER_SAMPLER"
]))
if
"VLLM_USE_FLASHINFER_SAMPLER"
in
os
.
environ
else
None
,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
...
...
vllm/v1/sample/ops/__init__.py
0 → 100644
View file @
371d04d3
vllm/v1/sample/ops/penalties.py
0 → 100644
View file @
371d04d3
from
typing
import
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.utils
import
(
apply_penalties
as
_apply_penalties
)
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
stop_token_ids
:
List
[
Set
[
int
]],
min_tokens
:
List
[
int
])
->
None
:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
if
(
len
(
output_token_ids
[
index
])
<
min_token
):
for
stop_token_id
in
stop_token_ids
[
index
]:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
def
apply_penalties
(
logits
:
torch
.
Tensor
,
prompt_token_ids
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]])
->
torch
.
Tensor
:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_
,
vocab_size
=
logits
.
shape
output_tokens_t
=
_convert_to_tensors
(
output_token_ids
,
vocab_size
,
logits
.
device
)
return
_apply_penalties
(
logits
,
prompt_token_ids
,
output_tokens_t
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
)
def
_convert_to_tensors
(
output_token_ids
:
List
[
List
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor
=
make_tensor_with_pad
(
output_token_ids
,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad
=
vocab_size
,
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
is_pin_memory_available
(),
)
return
output_tokens_tensor
.
to
(
device
,
non_blocking
=
True
)
vllm/v1/sample/ops/topk_topp_sampler.py
0 → 100644
View file @
371d04d3
from
typing
import
Dict
import
torch
import
torch.nn
as
nn
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
try
:
import
flashinfer.sampling
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
class
TopKTopPSampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
if
current_platform
.
is_cuda
:
if
is_flashinfer_available
:
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
is
not
False
:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger
.
info
(
"Using FlashInfer for top-p & top-k sampling."
)
self
.
forward
=
self
.
forward_cuda
else
:
logger
.
warning
(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"top-p & top-k sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1."
)
self
.
forward
=
self
.
forward_native
else
:
logger
.
warning
(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FalshInfer."
)
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_native
def
forward_native
(
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
logits
=
apply_top_k_top_p
(
logits
,
no_top_k
,
k
,
no_top_p
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
def
forward_cuda
(
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""More optimized implementation for top-k and top-p sampling."""
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
no_top_k
and
no_top_p
:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return
random_sample
(
probs
,
generators
)
return
flashinfer_sample
(
probs
,
no_top_k
,
k
,
no_top_p
,
p
,
generators
)
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
"""
if
no_top_k
and
no_top_p
:
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
top_p_mask
[:,
-
1
]
=
False
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
# Re-sort the probabilities.
logits
=
logits_sort
.
scatter
(
dim
=-
1
,
index
=
logits_idx
,
src
=
logits_sort
)
return
logits
def
random_sample
(
probs
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if
len
(
generators
)
!=
probs
.
shape
[
0
]:
q
.
exponential_
()
if
generators
:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for
i
,
generator
in
generators
.
items
():
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
flashinfer_sample
(
probs
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Sample from the probabilities using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert
not
(
no_top_k
and
no_top_p
)
max_top_k_round
=
32
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
if
len
(
generators
)
!=
batch_size
:
uniform_samples
.
uniform_
()
if
generators
:
for
i
,
generator
in
generators
.
items
():
uniform_samples
[:,
i
].
uniform_
(
generator
=
generator
)
if
no_top_k
:
# Top-p only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
probs
,
uniform_samples
,
p
,
deterministic
=
True
)
elif
no_top_p
:
# Top-k only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
deterministic
=
True
)
else
:
# Both top-k and top-p.
next_token_ids
,
success
=
(
flashinfer
.
sampling
.
top_k_top_p_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
p
,
deterministic
=
True
))
# NOTE: CPU-GPU synchronization happens here.
if
not
success
.
all
():
if
not
no_top_k
:
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
k
)
if
not
no_top_p
:
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
p
)
next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
],
deterministic
=
True
)
return
next_token_ids
.
view
(
-
1
)
vllm/v1/sample/sampler.py
View file @
371d04d3
"""A layer that samples the next tokens from the model's outputs."""
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_min_token_penalties
,
apply_penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
_SAMPLING_EPS
=
1e-5
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
_apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
stop_token_ids
,
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
_apply_penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
logits
=
self
.
apply_top_k_top_p
(
logits
,
sampling_metadata
)
probs
=
self
.
get_probs
(
logits
)
sampled
=
self
.
sample
(
probs
,
sampling_metadata
)
# Use int32 to reduce the tensor size.
sampled
=
sampled
.
to
(
torch
.
int32
)
if
sampling_metadata
.
max_num_logprobs
>
0
:
logprobs
=
self
.
get_logprobs
(
logits
)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
sampling_metadata
.
max_num_logprobs
,
dim
=-
1
)
# Use int32 to reduce the tensor size.
topk_indices
=
topk_indices
.
to
(
torch
.
int32
)
needs_logprobs
=
sampling_metadata
.
max_num_logprobs
>
0
if
needs_logprobs
:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# NOTE: We compute logprobs first because the below ops may
# modify the logits tensor in-place (and we don't want to clone
# the logits tensor for memory efficiency).
topk_logprobs
,
topk_indices
=
self
.
get_topk_logprobs
(
logits
,
sampling_metadata
)
else
:
topk_logprobs
=
None
topk_indices
=
None
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Apply temperature.
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Use int32 to reduce the tensor size.
sampled
=
sampled
.
to
(
torch
.
int32
)
# NOTE: CPU-GPU synchronization happens here.
sampler_output
=
SamplerOutput
(
sampled_token_ids
=
sampled
.
tolist
(),
...
...
@@ -63,71 +65,37 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
temp
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Use float32 to apply temperature scaling.
logits
=
logits
.
to
(
torch
.
float32
)
# Avoid division by zero.
temp
=
torch
.
where
(
temp
<
_SAMPLING_EPS
,
1.0
,
temp
)
# Use in-place division to avoid creating a new tensor.
logits
.
div_
(
temp
.
unsqueeze
(
dim
=
1
))
return
logits
def
apply_top_k_top_p
(
def
greedy_sample
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
return
_apply_top_k_top_p
(
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_greedy
:
return
self
.
greedy_sample
(
logits
)
random_sampled
=
self
.
topk_topp_sampler
(
logits
,
sampling_metadata
.
generators
,
sampling_metadata
.
no_top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
no_top_p
,
sampling_metadata
.
top_p
,
)
def
get_probs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
def
get_logprobs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
def
greedy_sample
(
self
,
probs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
probs
.
argmax
(
dim
=-
1
).
view
(
-
1
)
def
random_sample
(
self
,
probs
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if
len
(
generators
)
!=
probs
.
shape
[
0
]:
# This might still be done here unnecessarily if there are greedies
q
.
exponential_
()
if
generators
:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for
i
,
generator
in
generators
.
items
():
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
self
,
probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_greedy
:
return
self
.
greedy_sample
(
probs
)
if
sampling_metadata
.
all_random
:
return
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
return
random_sample
d
greedy_sampled
=
self
.
greedy_sample
(
probs
)
random_sampled
=
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
greedy_sampled
=
self
.
greedy_sample
(
logits
)
sampled
=
torch
.
where
(
sampling_metadata
.
temperature
<
_SAMPLING_EPS
,
greedy_sampled
,
...
...
@@ -135,86 +103,34 @@ class Sampler(nn.Module):
)
return
sampled
def
get_topk_logprobs
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logprobs
=
logits
.
log_softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
sampling_metadata
.
max_num_logprobs
,
dim
=-
1
)
# Use int32 to reduce the tensor size.
topk_indices
=
topk_indices
.
to
(
torch
.
int32
)
return
topk_logprobs
,
topk_indices
# TODO(woosuk): Optimize this with a custom kernel.
def
_apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
no_top_k
and
no_top_p
:
def
apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
stop_token_ids
,
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
top_p_mask
[:,
-
1
]
=
False
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
# Re-sort the probabilities.
logits
=
logits_sort
.
scatter
(
dim
=-
1
,
index
=
logits_idx
,
src
=
logits_sort
)
return
logits
def
_apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
stop_token_ids
:
List
[
Set
[
int
]],
min_tokens
:
List
[
int
]):
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
if
(
len
(
output_token_ids
[
index
])
<
min_token
):
for
stop_token_id
in
stop_token_ids
[
index
]:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
def
_apply_penalties
(
logits
:
torch
.
Tensor
,
prompt_token_ids
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]]):
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_
,
vocab_size
=
logits
.
shape
output_tokens_t
=
_convert_to_tensors
(
output_token_ids
,
vocab_size
,
logits
.
device
)
return
apply_penalties
(
logits
,
prompt_token_ids
,
output_tokens_t
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
)
def
_convert_to_tensors
(
output_token_ids
:
List
[
List
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor
=
make_tensor_with_pad
(
output_token_ids
,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad
=
vocab_size
,
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
is_pin_memory_available
(),
)
return
output_tokens_tensor
.
to
(
device
,
non_blocking
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment