Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69e1d23e
Unverified
Commit
69e1d23e
authored
Feb 16, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 16, 2025
Browse files
[V1][BugFix] Clean up rejection sampler & Fix warning msg (#13362)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
d67cc21b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
40 deletions
+69
-40
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+69
-40
No files found.
vllm/v1/sample/rejection_sampler.py
View file @
69e1d23e
...
@@ -3,7 +3,9 @@ import torch
...
@@ -3,7 +3,9 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn.utils.rnn
import
pad_sequence
from
torch.nn.utils.rnn
import
pad_sequence
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -19,27 +21,50 @@ INVALID_TOKEN_ID = -1
...
@@ -19,27 +21,50 @@ INVALID_TOKEN_ID = -1
class
RejectionSampler
(
nn
.
Module
):
class
RejectionSampler
(
nn
.
Module
):
def
forward
(
self
,
logits
:
torch
.
Tensor
,
def
__init__
(
self
):
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
super
().
__init__
()
if
not
sampling_metadata
.
all_greedy
:
if
current_platform
.
is_cuda
:
raise
NotImplementedError
(
"Only greedy sampling is supported by rejection sampler."
)
if
is_flashinfer_available
:
if
is_flashinfer_available
:
logger
.
info
(
"User FlashInfer for rejection sampling."
)
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
is
not
False
:
return
RejectionSampler
.
flashinfer_sample
(
logits
,
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
sampling_metadata
)
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger
.
info
(
"Using FlashInfer for rejection sampling."
)
self
.
forward_method
=
self
.
flashinfer_sample
else
:
logger
.
warning
(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"rejection sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1."
)
self
.
forward_method
=
self
.
forward_native
else
:
else
:
logger
.
warning
(
logger
.
warning
(
"FlashInfer is not available. Falling back to the PyTorch-"
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of rejection sampling."
)
"native implementation of rejection sampling. For the "
return
RejectionSampler
.
greedy_sample_native
(
"best performance, please install FlashInfer."
)
logits
,
sampling_metadata
)
self
.
forward_method
=
self
.
forward_native
else
:
self
.
forward_method
=
self
.
forward_native
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
if
not
sampling_metadata
.
all_greedy
:
raise
NotImplementedError
(
"Currently, only greedy sampling is supported by "
"rejection sampler."
)
return
self
.
forward_method
(
logits
,
sampling_metadata
)
@
staticmethod
def
flashinfer_sample
(
def
flashinfer_sample
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
# NOTE: The following input preparationg can be moved
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# to the model runner with a persistent manner for better
# performance.
# performance.
...
@@ -71,10 +96,10 @@ class RejectionSampler(nn.Module):
...
@@ -71,10 +96,10 @@ class RejectionSampler(nn.Module):
vocab_size
=
logits
.
size
(
-
1
)
vocab_size
=
logits
.
size
(
-
1
)
# NOTE: CPU <-> GPU synchronization happens here.
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids
=
draft_token_ids
.
to
(
logits
.
device
)
draft_token_ids
=
draft_token_ids
.
to
(
logits
.
device
)
draft_probs
=
RejectionSampler
.
_create_greedy_token_probs
(
draft_probs
=
_create_greedy_token_probs
(
draft_token_ids
,
vocab_size
,
draft_token_ids
,
vocab_size
,
logits
.
device
)
logits
.
device
)
target_probs
=
RejectionSampler
.
_create_greedy_token_probs
(
target_probs
=
_create_greedy_token_probs
(
target_token_ids
,
vocab_size
,
target_token_ids
,
vocab_size
,
logits
.
device
)
logits
.
device
)
uniform_samples
=
torch
.
zeros
(
batch_size
,
uniform_samples
=
torch
.
zeros
(
batch_size
,
max_spec_len
+
1
,
max_spec_len
+
1
,
device
=
logits
.
device
)
device
=
logits
.
device
)
...
@@ -89,10 +114,11 @@ class RejectionSampler(nn.Module):
...
@@ -89,10 +114,11 @@ class RejectionSampler(nn.Module):
logprobs_tensors
=
None
)
logprobs_tensors
=
None
)
# TODO: The following method can be optimized for better performance.
# TODO: The following method can be optimized for better performance.
@
staticmethod
def
forward_native
(
def
greedy_sample_native
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
# Add 1 to include the 'bonus' token.
# Add 1 to include the 'bonus' token.
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
...
@@ -137,9 +163,12 @@ class RejectionSampler(nn.Module):
...
@@ -137,9 +163,12 @@ class RejectionSampler(nn.Module):
return
SamplerOutput
(
sampled_token_ids
=
output_token_ids
,
return
SamplerOutput
(
sampled_token_ids
=
output_token_ids
,
logprobs_tensors
=
None
)
logprobs_tensors
=
None
)
@
staticmethod
def
_create_greedy_token_probs
(
token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
def
_create_greedy_token_probs
(
out_device
:
torch
.
device
)
->
torch
.
Tensor
:
token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
out_device
:
torch
.
device
,
)
->
torch
.
Tensor
:
batch_size
,
num_tokens
=
token_ids
.
shape
batch_size
,
num_tokens
=
token_ids
.
shape
token_probs
=
torch
.
zeros
(
batch_size
,
token_probs
=
torch
.
zeros
(
batch_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment