Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6bd70d7
"vscode:/vscode.git/clone" did not exist on "bf668b5bf56644db8e90cd0d385b62cc15a4657a"
Unverified
Commit
c6bd70d7
authored
Sep 22, 2024
by
Lily Liu
Committed by
GitHub
Sep 22, 2024
Browse files
[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
parent
5b595327
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
33 additions
and
115 deletions
+33
-115
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+7
-23
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+20
-59
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+1
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+1
-8
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+1
-14
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+2
-7
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+1
-3
No files found.
tests/samplers/test_rejection_sampler.py
View file @
c6bd70d7
...
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
...
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"which_tokens_accepted"
,
"which_tokens_accepted"
,
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
,
device
:
str
,
use_flashinfer
:
bool
):
use_flashinfer
:
bool
):
"""Verify the output has correct format given predetermined accepted matrix.
"""Verify the output has correct format given predetermined accepted matrix.
"""
"""
if
use_flashinfer
and
disable_bonus_tokens
:
pytest
.
skip
(
"Flashinfer rejection sampler must enable bonus token."
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
rejection_sampler
=
RejectionSampler
(
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
disable_bonus_tokens
=
disable_bonus_tokens
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
accepted
,
accepted
,
...
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
)
)
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
if
disable_bonus_tokens
:
expected_bonus_token_ids
=
expected_bonus_token_ids
*
0
-
1
if
which_tokens_accepted
==
"all_tokens_accepted"
:
if
which_tokens_accepted
==
"all_tokens_accepted"
:
# Expect all tokens to be equal to draft tokens.
# Expect all tokens to be equal to draft tokens.
...
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
device
:
str
,
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
...
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
use_flashinfer
:
bool
):
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
...
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
}
}
for
use_flashinfer
in
[
True
,
False
]:
for
use_flashinfer
in
[
True
,
False
]:
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
)
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
# We use seeded sequences to ensure the same tokens are accepted
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
# for both flashinfer and nonflashinfer backends.
...
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
,
use_flashinfer
=
use_flashinfer
,
strict_mode
=
True
)
strict_mode
=
True
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
...
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
set_random_seed
(
seed
)
set_random_seed
(
seed
)
helper
=
_CorrectnessTestHelper
(
helper
=
_CorrectnessTestHelper
(
vocab_size
=
10
,
vocab_size
=
10
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
rejection_sampler
=
RejectionSampler
(
use_flashinfer
=
use_flashinfer
),
use_flashinfer
=
use_flashinfer
),
)
)
draft_probs
,
target_probs
,
reference_probs
=
helper
.
generate_probs_for_test
(
draft_probs
,
target_probs
,
reference_probs
=
helper
.
generate_probs_for_test
(
...
...
tests/samplers/test_typical_acceptance_sampler.py
View file @
c6bd70d7
...
@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
...
@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def
get_acceptance_sampler
(
def
get_acceptance_sampler
(
posterior_threshold
:
float
=
0.03
,
posterior_threshold
:
float
=
0.03
,
posterior_alpha
:
float
=
0.9
,
posterior_alpha
:
float
=
0.9
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
strict_mode
:
bool
=
False
,
)
->
TypicalAcceptanceSampler
:
)
->
TypicalAcceptanceSampler
:
"""
"""
Initializes and returns a TypicalAcceptanceSampler.
Initializes and returns a TypicalAcceptanceSampler.
"""
"""
return
TypicalAcceptanceSampler
(
posterior_threshold
,
posterior_alpha
,
return
TypicalAcceptanceSampler
(
posterior_threshold
,
posterior_alpha
,
disable_bonus_tokens
,
strict_mode
)
strict_mode
)
@
pytest
.
mark
.
parametrize
(
"k"
,
list
(
range
(
1
,
6
)))
@
pytest
.
mark
.
parametrize
(
"k"
,
list
(
range
(
1
,
6
)))
...
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_uniform_target_distribution_accepts_all_tokens
(
def
test_uniform_target_distribution_accepts_all_tokens
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
):
seed
:
int
,
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler with a uniform target probability
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
distribution.
...
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
...
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
This test verifies that when provided with a uniform target probability
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
draft tokens being accepted.
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
"""
"""
set_random_seed
(
seed
)
set_random_seed
(
seed
)
k
=
3
k
=
3
batch_size
=
5
batch_size
=
5
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
k
+
1
,
...
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
...
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
# should lead to all draft tokens being accepted. Verify that.
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
.
squeeze
())
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
.
squeeze
())
assert
torch
.
all
(
output_token_ids
[:,
:
k
]
==
draft_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
:
k
]
==
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_temperature_zero_target_distribution
(
seed
:
int
,
def
test_temperature_zero_target_distribution
(
seed
:
int
,
device
:
str
):
disable_bonus_tokens
:
bool
,
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
probability distribution.
...
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
...
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Simulate temperature 0 probability distribution for target probabilities
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# and create target probabilities such that only 1 token id has
...
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
...
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_mixed_target_distribution
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
def
test_mixed_target_distribution
(
seed
:
int
,
device
:
str
):
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler with a mixed target probability
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
distribution.
...
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
...
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
- For sequences with a uniform distribution, all draft tokens are
accepted.
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
"""
set_random_seed
(
seed
)
set_random_seed
(
seed
)
k
=
3
k
=
3
batch_size
=
4
batch_size
=
4
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# For sequences 0 and 2 set the distribution to a temperature
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# zero distribution. For sequences 1 and 3 set it to a uniform
...
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
...
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0
]))
0
]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# target probability distribution is uniform. In addition verify that
#
if disable_bonus_tokens is false then
we also accept the bonus tokens.
# we also accept the bonus tokens.
assert
torch
.
all
(
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
:
-
1
]
==
draft_token_ids
[[
1
,
3
],
:])
output_token_ids
[[
1
,
3
],
:
-
1
]
==
draft_token_ids
[[
1
,
3
],
:])
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
!=
-
1
)
assert
torch
.
all
(
output_token_ids
[[
1
,
3
],
-
1
]
!=
-
1
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_accept_tokens_partially
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
def
test_accept_tokens_partially
(
seed
:
int
,
device
:
str
):
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
tokens should be accepted.
...
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
...
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size
=
1
batch_size
=
1
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Create a temperature zero target probability distribution and ensure
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# all draft token ids correspond to the tokens with 1.0 probability.
...
@@ -384,9 +361,6 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
...
@@ -384,9 +361,6 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
# Next only keep the first 2 draft tokens same as the zero temperature
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# tokens. For the remaining 3 choose some other tokens. In the
...
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
...
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
1
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
1
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_accept_tokens_set_non_default_posteriors
(
seed
:
int
,
def
test_accept_tokens_set_non_default_posteriors
(
seed
:
int
,
device
:
str
):
disable_bonus_tokens
:
bool
,
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
alpha values. This test verifies that by modifying the posterior
...
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
...
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size
=
1
batch_size
=
1
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
# Simulate temperature 0 probability distribution for target
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# probabilities and create target probabilities such that only 1 token
...
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
...
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
# target distribution. Simulate and verify the same.
typical_acceptance_sampler
=
TypicalAcceptanceSampler
(
typical_acceptance_sampler
=
TypicalAcceptanceSampler
(
strict_mode
=
True
,
strict_mode
=
True
,
posterior_threshold
=
0.0
,
posterior_alpha
=
0.0
)
disable_bonus_tokens
=
disable_bonus_tokens
,
posterior_threshold
=
0.0
,
posterior_alpha
=
0.0
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_probs
,
...
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
...
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
if
disable_bonus_tokens
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
else
:
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
bonus_token_ids
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_replacement_token_ids
(
seed
:
int
,
disable_bonus_tokens
:
bool
,
def
test_replacement_token_ids
(
seed
:
int
,
device
:
str
):
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler's method for generating
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
replacement token IDs.
...
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
...
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size
=
5
batch_size
=
5
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
expected_replacement_tokens
=
-
torch
.
ones
(
expected_replacement_tokens
=
-
torch
.
ones
(
...
...
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
c6bd70d7
...
@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
...
@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
# speculative model
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
# max
.
number of speculative tokens: this corresponds to
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
5
MAX_SPEC_TOKENS
=
5
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
c6bd70d7
...
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
,
strict_mode
:
bool
=
False
,
use_flashinfer
:
Optional
[
bool
]
=
None
):
use_flashinfer
:
Optional
[
bool
]
=
None
):
"""Create a rejection sampler.
"""Create a rejection sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
...
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable.
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
This parameter is only used for testing purposes.
"""
"""
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
super
().
__init__
(
strict_mode
=
strict_mode
)
strict_mode
=
strict_mode
)
if
use_flashinfer
is
None
:
if
use_flashinfer
is
None
:
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
chain_speculative_sampling
is
not
None
)
chain_speculative_sampling
is
not
None
)
...
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self
.
use_flashinfer
=
use_flashinfer
self
.
use_flashinfer
=
use_flashinfer
if
self
.
use_flashinfer
:
if
self
.
use_flashinfer
:
assert
not
disable_bonus_tokens
,
\
"flashinfer will enable bonus token by default"
logger
.
info
(
"Use flashinfer for rejection sampling."
)
logger
.
info
(
"Use flashinfer for rejection sampling."
)
else
:
else
:
logger
.
info
(
"Use pytorch for rejection sampling."
)
logger
.
info
(
"Use pytorch for rejection sampling."
)
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
c6bd70d7
...
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
step.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
"""Base class constructor.
"""Base class constructor.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
bonus_token_ids
,
-
1
)
bonus_token_ids
,
-
1
)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
output
.
mul_
(
~
after_false_mask
).
add_
(
substitute_token_ids
.
mul
(
after_false_mask
))
substitute_token_ids
.
mul
(
after_false_mask
))
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
c6bd70d7
...
@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
,
self
,
posterior_threshold
:
float
,
posterior_threshold
:
float
,
posterior_alpha
:
float
,
posterior_alpha
:
float
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
strict_mode
:
bool
=
False
,
):
):
"""Create a Typical Acceptance Sampler.
"""Create a Typical Acceptance Sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
...
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""
"""
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_alpha
=
posterior_alpha
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
super
().
__init__
(
strict_mode
=
strict_mode
)
strict_mode
=
strict_mode
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted.
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
In the case where all draft tokens are accepted, the bonus token will be
accepted
conditioned on self._disable_bonus_tokens being false
.
accepted.
Args:
Args:
target_probs: The probability distribution over token ids given
target_probs: The probability distribution over token ids given
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
c6bd70d7
...
@@ -164,11 +164,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -164,11 +164,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
if
draft_token_acceptance_method
==
"rejection_sampler"
:
if
draft_token_acceptance_method
==
"rejection_sampler"
:
spec_decode_sampler
=
RejectionSampler
(
spec_decode_sampler
=
RejectionSampler
()
disable_bonus_tokens
=
False
,
)
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
spec_decode_sampler
=
TypicalAcceptanceSampler
(
spec_decode_sampler
=
TypicalAcceptanceSampler
(
disable_bonus_tokens
=
False
,
posterior_threshold
=
\
posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment