Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6bd70d7
Unverified
Commit
c6bd70d7
authored
Sep 22, 2024
by
Lily Liu
Committed by
GitHub
Sep 22, 2024
Browse files
[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
parent
5b595327
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
33 additions
and
115 deletions
+33
-115
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+7
-23
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+20
-59
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+1
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+1
-8
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+1
-14
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+2
-7
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+1
-3
No files found.
tests/samplers/test_rejection_sampler.py
View file @
c6bd70d7
...
...
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
@
pytest
.
mark
.
parametrize
(
"which_tokens_accepted"
,
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
,
use_flashinfer
:
bool
):
device
:
str
,
use_flashinfer
:
bool
):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if
use_flashinfer
and
disable_bonus_tokens
:
pytest
.
skip
(
"Flashinfer rejection sampler must enable bonus token."
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
accepted
,
...
...
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
)
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
if
disable_bonus_tokens
:
expected_bonus_token_ids
=
expected_bonus_token_ids
*
0
-
1
if
which_tokens_accepted
==
"all_tokens_accepted"
:
# Expect all tokens to be equal to draft tokens.
...
...
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
...
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
...
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
}
for
use_flashinfer
in
[
True
,
False
]:
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
...
...
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
,
strict_mode
=
True
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
...
...
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
set_random_seed
(
seed
)
helper
=
_CorrectnessTestHelper
(
vocab_size
=
10
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
),
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
),
)
draft_probs
,
target_probs
,
reference_probs
=
helper
.
generate_probs_for_test
(
...
...
tests/samplers/test_typical_acceptance_sampler.py
View file @
c6bd70d7
...
...
@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def
get_acceptance_sampler
(
posterior_threshold
:
float
=
0.03
,
posterior_alpha
:
float
=
0.9
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
)
->
TypicalAcceptanceSampler
:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return
TypicalAcceptanceSampler
(
posterior_threshold
,
posterior_alpha
,
disable_bonus_tokens
,
strict_mode
)
strict_mode
)
@
pytest
.
mark
.
parametrize
(
"k"
,
list
(
range
(
1
,
6
)))
...
...
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_uniform_target_distribution_accepts_all_tokens
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
...
...
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
draft tokens being accepted.
"""
set_random_seed
(
seed
)
k
=
3
batch_size
=
5
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
...
...
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
.
squeeze
())
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
.
squeeze
())
assert
torch
.
all
(
output_token_ids
[:,
:
k
]
==
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_temperature_zero_target_distribution
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
def
test_temperature_zero_target_distribution
(
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
...
...
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
...
...
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_mixed_target_distribution
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
def
test_mixed_target_distribution
(
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
...
...
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed
(
seed
)
k
=
3
batch_size
=
4
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
...
...
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0
]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
#
if disable_bonus_tokens is false then
we also accept the bonus tokens.
# we also accept the bonus tokens.
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
:
-
1
]
==
draft_token_ids
[[
1
,
3
],
:])
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
!=
-
1
)
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
!=
-
1
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_accept_tokens_partially
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
def
test_accept_tokens_partially
(
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
...
...
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size
=
1
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
...
...
@@ -384,10 +361,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
...
...
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
1
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_accept_tokens_set_non_default_posteriors
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
def
test_accept_tokens_set_non_default_posteriors
(
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
...
...
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size
=
1
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
...
...
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler
=
TypicalAcceptanceSampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
,
posterior_threshold
=
0.0
,
posterior_alpha
=
0.0
)
strict_mode
=
True
,
posterior_threshold
=
0.0
,
posterior_alpha
=
0.0
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
...
...
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_replacement_token_ids
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
def
test_replacement_token_ids
(
seed
:
int
,
device
:
str
):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
...
...
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size
=
5
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
expected_replacement_tokens
=
-
torch
.
ones
(
...
...
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
c6bd70d7
...
...
@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
# max
.
number of speculative tokens: this corresponds to
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
5
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
c6bd70d7
...
...
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
,
use_flashinfer
:
Optional
[
bool
]
=
None
):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
...
...
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
super
().
__init__
(
strict_mode
=
strict_mode
)
if
use_flashinfer
is
None
:
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
chain_speculative_sampling
is
not
None
)
...
...
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self
.
use_flashinfer
=
use_flashinfer
if
self
.
use_flashinfer
:
assert
not
disable_bonus_tokens
,
\
"flashinfer will enable bonus token by default"
logger
.
info
(
"Use flashinfer for rejection sampling."
)
else
:
logger
.
info
(
"Use pytorch for rejection sampling."
)
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
c6bd70d7
...
...
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
"""
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
...
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
bonus_token_ids
,
-
1
)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
substitute_token_ids
.
mul
(
after_false_mask
))
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
c6bd70d7
...
...
@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
,
posterior_threshold
:
float
,
posterior_alpha
:
float
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
):
"""Create a Typical Acceptance Sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
...
...
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
super
().
__init__
(
strict_mode
=
strict_mode
)
def
forward
(
self
,
...
...
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
accepted
conditioned on self._disable_bonus_tokens being false
.
accepted.
Args:
target_probs: The probability distribution over token ids given
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
c6bd70d7
...
...
@@ -164,11 +164,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
if
draft_token_acceptance_method
==
"rejection_sampler"
:
spec_decode_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
)
spec_decode_sampler
=
RejectionSampler
()
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
spec_decode_sampler
=
TypicalAcceptanceSampler
(
disable_bonus_tokens
=
False
,
posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment