Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
80ca1e6a
Unverified
Commit
80ca1e6a
authored
Jul 01, 2024
by
sroy745
Committed by
GitHub
Jul 01, 2024
Browse files
[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
parent
614aa512
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
482 additions
and
210 deletions
+482
-210
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+64
-32
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+53
-1
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+7
-5
tests/spec_decode/test_metrics.py
tests/spec_decode/test_metrics.py
+47
-47
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+85
-69
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+22
-0
vllm/config.py
vllm/config.py
+73
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+41
-1
vllm/engine/metrics.py
vllm/engine/metrics.py
+1
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+9
-9
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+14
-1
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+10
-12
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+13
-11
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+43
-19
No files found.
tests/samplers/test_typical_acceptance_sampler.py
View file @
80ca1e6a
...
...
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
return
draft_token_ids
def
get_acceptance_sampler
(
posterior_threshold
:
float
=
0.03
,
posterior_alpha
:
float
=
0.9
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
)
->
TypicalAcceptanceSampler
:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return
TypicalAcceptanceSampler
(
posterior_threshold
,
posterior_alpha
,
disable_bonus_tokens
,
strict_mode
)
@
pytest
.
mark
.
parametrize
(
"k"
,
list
(
range
(
1
,
6
)))
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
32
)))
...
...
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
different combinations of k, vocab_size, batch_size and num devices.
"""
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
()
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
()
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
...
...
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"above_or_below_vocab_range"
,
[
"above"
,
"below"
])
...
...
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
batch_size
=
5
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
strict_mode
=
True
)
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
...
...
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids
[
0
][
0
]
=
rogue_token_id
with
pytest
.
raises
(
AssertionError
):
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
...
...
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
batch_size
=
5
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
...
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
...
...
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
# Simulate temperature 0 probability distribution for target probabilities
...
...
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
-
1
]
==
-
1
)
...
...
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
batch_size
=
4
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
# For sequences 0 and 2 set the distribution to a temperature
...
...
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
# verify the shape of output_token_ids
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
...
...
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size
=
1
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
# Create a temperature zero target probability distribution and ensure
...
...
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
...
...
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size
,
k
,
vocab_size
,
zero_temperature_token_ids
)
draft_token_ids
=
torch
.
cat
(
(
draft_token_ids
[:,
:
2
],
draft_token_ids_to_replace
[:,
-
3
:]),
dim
=
1
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
:
2
]
==
draft_token_ids
[:,
:
2
])
...
...
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size
=
1
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
# Simulate temperature 0 probability distribution for target
...
...
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
1
:
-
1
]
==
-
1
)
...
...
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
posterior_threshold
=
0.0
,
posterior_alpha
=
0.0
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_token_ids
)
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
0
:
-
1
]
==
draft_token_ids
)
...
...
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size
=
5
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
TypicalA
cceptance
S
ampler
(
typical_acceptance_sampler
=
get_a
cceptance
_s
ampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
rank
=
0
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
80ca1e6a
...
...
@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
...
...
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
}
# Try a range of common k.
for
k
in
[
1
,
2
,
3
]
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_typical_acceptance_sampling
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
tests/spec_decode/test_dynamic_spec_decode.py
View file @
80ca1e6a
...
...
@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
create_batch
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'queue_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
):
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size
=
3
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
rejection_sampler
=
rejection_sampler
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
=
metrics_collector
,
disable_by_batch_size
=
disable_by_batch_size
)
...
...
tests/spec_decode/test_metrics.py
View file @
80ca1e6a
...
...
@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
def
test_initial_call_returns_none
():
"""Expect first call to get metrics to return None.
"""
rej
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
rej
_sampler
)
spec_decode
_sampler
=
MagicMock
()
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
spec_decode
_sampler
)
collector
.
init_gpu_tensors
(
rank
=
0
)
maybe_metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
maybe_metrics
is
None
...
...
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
def
test_second_call_returns_metrics
():
"""Expect second call to not return None.
"""
rej
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
=
MagicMock
()
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
timer
=
MagicMock
()
...
...
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
...
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
def
test_nonzero_rank_noop
(
rank
):
"""Verify nonzero ranks don't collect metrics.
"""
rej
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
rej
_sampler
)
spec_decode
_sampler
=
MagicMock
()
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
spec_decode
_sampler
)
collector
.
init_gpu_tensors
(
rank
=
rank
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
...
...
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
def
test_noop_until_time
():
"""Verify metrics aren't collected until enough time passes.
"""
rej
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
=
MagicMock
()
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
timer
=
MagicMock
()
...
...
@@ -91,7 +91,7 @@ def test_noop_until_time():
collect_interval_s
+
0.1
,
collect_interval_s
+
0.1
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
...
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
max_num_emitted_tokens
=
AsyncMetricsCollector
.
get_max_num_emitted_tokens
(
num_draft_tokens
,
k
)
rej
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
num_accepted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
num_emitted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
num_draft_tokens
spec_decode
_sampler
=
MagicMock
()
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
num_accepted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
num_emitted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode
_sampler
.
num_draft_tokens
=
num_draft_tokens
collect_interval_s
=
5.0
timer
=
MagicMock
()
...
...
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
80ca1e6a
...
...
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
,
SequenceOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
...
...
@@ -16,23 +15,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_calls_draft_model
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_draft_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
...
...
@@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
...
...
@@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
worker
.
init_device
()
vocab_size
=
32_000
...
...
@@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_calls_rejection_sampler
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_spec_decode_sampler
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
...
...
@@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
worker
.
init_device
()
...
...
@@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
exception_secret
=
'artificial stop'
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
spec_decode_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
len
(
rejection
_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
rejection
_sampler
.
call_args_list
[
0
]
assert
len
(
spec_decode
_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
spec_decode
_sampler
.
call_args_list
[
0
]
actual
=
SimpleNamespace
(
**
kwargs
)
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
...
...
@@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_formats_output
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_formats_output
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
...
...
@@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
worker
.
init_device
()
...
...
@@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
rejection
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
spec_decode
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
rejection
_sampler_output
[
i
][
spec_decode
_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
rejection_sampler
.
return_value
=
rejection_sampler_output
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
token_ids
=
rejection
_sampler_output
.
transpose
(
0
,
1
),
token_ids
=
spec_decode
_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
...
...
@@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'returns_metrics'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_collects_metrics
(
k
:
int
,
batch_size
:
int
,
returns_metrics
:
bool
):
def
test_collects_metrics
(
k
:
int
,
batch_size
:
int
,
returns_metrics
:
bool
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size
=
32_000
...
...
@@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
worker
.
init_device
()
...
...
@@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
rejection
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
spec_decode
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
rejection
_sampler_output
[
i
][
spec_decode
_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
rejection_sampler
.
return_value
=
rejection_sampler_output
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
mock_rejsample_metrics
=
MagicMock
(
spec
=
SpecDecodeWorkerMetrics
)
if
returns_metrics
else
None
...
...
@@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_k_equals_zero
(
k
:
int
,
batch_size
:
int
):
def
test_k_equals_zero
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
...
...
@@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int):
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
...
...
@@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_empty_input_batch
(
k
:
int
,
batch_size
:
int
):
def
test_empty_input_batch
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
...
...
@@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int):
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
...
...
@@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int):
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_init_device
():
def
test_init_device
(
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
worker
.
init_device
()
...
...
@@ -549,22 +562,23 @@ def test_init_device():
target_worker
.
init_device
.
assert_called_once
()
metrics_collector
.
init_gpu_tensors
.
assert_called_once
()
rejection
_sampler
.
init_gpu_tensors
.
assert_called_once
()
spec_decode
_sampler
.
init_gpu_tensors
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_initialize_cache
():
def
test_initialize_cache
(
acceptance_sampler_method
):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
worker
.
initialize_cache
(
**
kwargs
)
...
...
@@ -577,19 +591,20 @@ def test_initialize_cache():
@
pytest
.
mark
.
parametrize
(
'available_cpu_blocks'
,
[
500
])
@
pytest
.
mark
.
parametrize
(
'target_cache_block_size_bytes'
,
[
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
'draft_kv_size_bytes'
,
[
0
,
2
*
2
*
768
,
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_determine_num_available_blocks
(
available_gpu_blocks
:
int
,
available_cpu_blocks
:
int
,
target_cache_block_size_bytes
:
int
,
draft_kv_size_bytes
:
int
):
draft_kv_size_bytes
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
determine_num_available_blocks
.
return_value
=
(
...
...
@@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
target_cache_block_size_bytes
)
draft_worker
.
get_cache_block_size_bytes
.
return_value
=
draft_kv_size_bytes
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
num_gpu_blocks
,
num_cpu_blocks
=
worker
.
determine_num_available_blocks
()
...
...
tests/spec_decode/test_utils.py
View file @
80ca1e6a
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
SequenceGroupMetadata
,
get_all_seq_ids
from
vllm.spec_decode.util
import
split_batch_by_proposal_len
...
...
@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
assert
filtered_groups
==
[]
assert
indices
==
[]
def
mock_spec_decode_sampler
(
acceptance_sampler_method
):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if
acceptance_sampler_method
==
"rejection_sampler"
:
sampler
=
MagicMock
(
spec
=
RejectionSampler
)
sampler
.
token_id_dtype
=
torch
.
int64
return
sampler
elif
acceptance_sampler_method
==
"typical_acceptance_sampler"
:
sampler
=
MagicMock
(
spec
=
TypicalAcceptanceSampler
)
sampler
.
token_id_dtype
=
torch
.
int64
return
sampler
else
:
raise
ValueError
(
f
"Invalid sampler name
{
acceptance_sampler_method
}
"
)
vllm/config.py
View file @
80ca1e6a
...
...
@@ -753,7 +753,6 @@ class SchedulerConfig:
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
...
...
@@ -834,6 +833,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
Optional
[
float
],
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
...
...
@@ -870,7 +872,20 @@ class SpeculativeConfig:
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
...
...
@@ -984,6 +999,11 @@ class SpeculativeConfig:
"speculative_model unless the draft model config contains an "
"n_predict parameter."
)
if
typical_acceptance_sampler_posterior_threshold
is
None
:
typical_acceptance_sampler_posterior_threshold
=
0.09
if
typical_acceptance_sampler_posterior_alpha
is
None
:
typical_acceptance_sampler_posterior_alpha
=
0.3
return
SpeculativeConfig
(
draft_model_config
,
draft_parallel_config
,
...
...
@@ -991,6 +1011,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
draft_token_acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
,
)
@
staticmethod
...
...
@@ -1072,6 +1097,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
):
"""Create a SpeculativeConfig object.
...
...
@@ -1085,6 +1113,19 @@ class SpeculativeConfig:
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
...
...
@@ -1093,6 +1134,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
draft_token_acceptance_method
=
draft_token_acceptance_method
self
.
typical_acceptance_sampler_posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
self
.
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
self
.
_verify_args
()
...
...
@@ -1104,6 +1150,31 @@ class SpeculativeConfig:
if
self
.
draft_model_config
:
self
.
draft_model_config
.
verify_with_parallel_config
(
self
.
draft_parallel_config
)
# Validate and set draft token acceptance related settings.
if
(
self
.
draft_token_acceptance_method
is
None
):
raise
ValueError
(
"draft_token_acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler."
)
if
(
self
.
draft_token_acceptance_method
!=
'rejection_sampler'
and
self
.
draft_token_acceptance_method
!=
'typical_acceptance_sampler'
):
raise
ValueError
(
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f
"is
{
self
.
draft_token_acceptance_method
}
"
)
if
(
self
.
typical_acceptance_sampler_posterior_threshold
<
0
or
self
.
typical_acceptance_sampler_posterior_alpha
<
0
):
raise
ValueError
(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f
"typical_acceptance_sampler_posterior_threshold = "
f
"
{
self
.
typical_acceptance_sampler_posterior_threshold
}
and "
f
"typical_acceptance_sampler_posterior_alpha = "
f
"
{
self
.
typical_acceptance_sampler_posterior_alpha
}
"
)
@
property
def
num_lookahead_slots
(
self
)
->
int
:
...
...
vllm/engine/arg_utils.py
View file @
80ca1e6a
...
...
@@ -100,7 +100,9 @@ class EngineArgs:
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
spec_decoding_acceptance_method
:
str
=
'rejection_sampler'
typical_acceptance_sampler_posterior_threshold
:
Optional
[
float
]
=
None
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
]
=
None
qlora_adapter_name_or_path
:
Optional
[
str
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
...
...
@@ -577,6 +579,38 @@ class EngineArgs:
help
=
'Min size of window for ngram prompt lookup in speculative '
'decoding.'
)
parser
.
add_argument
(
'--spec-decoding-acceptance-method'
,
type
=
str
,
default
=
EngineArgs
.
spec_decoding_acceptance_method
,
choices
=
[
'rejection_sampler'
,
'typical_acceptance_sampler'
],
help
=
'Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.'
)
parser
.
add_argument
(
'--typical-acceptance-sampler-posterior-threshold'
,
type
=
float
,
default
=
EngineArgs
.
typical_acceptance_sampler_posterior_threshold
,
help
=
'Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding. Defaults to 0.09'
)
parser
.
add_argument
(
'--typical-acceptance-sampler-posterior-alpha'
,
type
=
float
,
default
=
EngineArgs
.
typical_acceptance_sampler_posterior_alpha
,
help
=
'A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3'
)
parser
.
add_argument
(
'--model-loader-extra-config'
,
type
=
nullable_str
,
default
=
EngineArgs
.
model_loader_extra_config
,
...
...
@@ -737,6 +771,12 @@ class EngineArgs:
use_v2_block_manager
=
self
.
use_v2_block_manager
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
\
self
.
spec_decoding_acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
self
.
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
,
)
scheduler_config
=
SchedulerConfig
(
...
...
vllm/engine/metrics.py
View file @
80ca1e6a
...
...
@@ -457,4 +457,4 @@ class PrometheusStatLogger(StatLoggerBase):
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls
=
RayMetrics
_metrics_cls
=
RayMetrics
\ No newline at end of file
vllm/model_executor/layers/rejection_sampler.py
View file @
80ca1e6a
...
...
@@ -3,13 +3,12 @@ from typing import Tuple
import
torch
import
torch.jit
import
torch.nn
as
nn
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
class
RejectionSampler
(
SpecDecodeBaseSampler
,
nn
.
Module
):
class
RejectionSampler
(
SpecDecodeBaseSampler
):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
...
...
@@ -28,8 +27,8 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
SpecDecodeBaseSampler
.
__init__
(
self
,
disable_bonus_tokens
,
strict_mode
)
nn
.
Module
.
__init__
(
self
)
super
()
.
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
def
forward
(
self
,
...
...
@@ -78,11 +77,12 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
self
.
_raise_if_incorrect_input
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
accepted
,
recovered_token_ids
=
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
)
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
))
output_token_ids
=
self
.
_create_output
(
accepted
,
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
80ca1e6a
from
abc
import
abstractmethod
from
typing
import
Optional
import
torch
import
torch.jit
import
torch.nn
as
nn
class
SpecDecodeBaseSampler
():
class
SpecDecodeBaseSampler
(
nn
.
Module
):
"""Base class for samplers used for Speculative Decoding verification
step.
"""
...
...
@@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
def
token_id_dtype
(
self
):
return
torch
.
int64
@
abstractmethod
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
_create_output
(
self
,
accepted
:
torch
.
Tensor
,
# [batch_size, k]
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
80ca1e6a
import
torch
import
torch.jit
import
torch.nn
as
nn
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
class
TypicalAcceptanceSampler
(
SpecDecodeBaseSampler
,
nn
.
Module
):
class
TypicalAcceptanceSampler
(
SpecDecodeBaseSampler
):
"""Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads"
...
...
@@ -15,10 +14,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
def
__init__
(
self
,
posterior_threshold
:
float
,
posterior_alpha
:
float
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
posterior_threshold
:
float
=
0.09
,
posterior_alpha
:
float
=
0.3
,
):
"""Create a Typical Acceptance Sampler.
...
...
@@ -31,23 +30,20 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
nontrivial latency.
posterior_threshold : A threshold value that sets a lower bound
on the posterior probability of a token in target model for it
to be accepted.
Default is 0.09
to be accepted.
posterior_alpha : A scaling factor for the entropy-based
threshold in typical acceptance sampling. Typically defaults to
sqrt of posterior_threshold and is set to 0.3.
threshold in typical acceptance sampling.
"""
SpecDecodeBaseSampler
.
__init__
(
self
,
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
nn
.
Module
.
__init__
(
self
)
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Sample token ids using typical acceptance sampling. This accepts
...
...
@@ -69,6 +65,8 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: This parameter is unused by the acceptance sampler.
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
...
...
vllm/spec_decode/metrics.py
View file @
80ca1e6a
...
...
@@ -4,7 +4,8 @@ from typing import Callable, Optional
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -46,15 +47,15 @@ Timer = Callable[[], float]
class
AsyncMetricsCollector
:
"""Class which copies rejection
sampler metrics from the device to CPU on a
non-default Torch stream.
"""Class which copies rejection
/typical-acceptance sampler metrics
from the device to CPU on a
non-default Torch stream.
"""
def
__init__
(
self
,
rejection_sampler
:
Rejection
Sampler
,
spec_decode_sampler
:
SpecDecodeBase
Sampler
,
timer
:
Optional
[
Timer
]
=
None
,
collect_interval_s
:
float
=
5.0
):
self
.
_rejection_sampler
=
rejection
_sampler
self
.
spec_decode_sampler
=
spec_decode
_sampler
self
.
_timer
=
time
.
time
if
timer
is
None
else
timer
self
.
_rank
:
Optional
[
int
]
=
None
...
...
@@ -95,7 +96,7 @@ class AsyncMetricsCollector:
return
None
def
_should_collect_rejsample_metrics
(
self
,
now
:
float
)
->
bool
:
"""Return whether or not this iteration should print
rejection
sampling
"""Return whether or not this iteration should print sampling
metrics.
"""
if
self
.
_rank
!=
0
:
...
...
@@ -107,8 +108,8 @@ class AsyncMetricsCollector:
return
True
def
_copy_rejsample_metrics_async
(
self
)
->
torch
.
cuda
.
Event
:
"""Copy rejection
sampling metrics (number of accepted tokens, etc) to
CPU asynchronously.
"""Copy rejection
/typical-acceptance sampling metrics
(number of accepted tokens, etc) to
CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
"""
...
...
@@ -117,13 +118,14 @@ class AsyncMetricsCollector:
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
self
.
_aggregate_num_accepted_tokens
.
copy_
(
self
.
_rejection_sampler
.
num_accepted_tokens
,
non_blocking
=
True
)
self
.
spec_decode_sampler
.
num_accepted_tokens
,
non_blocking
=
True
)
self
.
_aggregate_num_emitted_tokens
.
copy_
(
self
.
_rejection
_sampler
.
num_emitted_tokens
,
non_blocking
=
True
)
self
.
spec_decode
_sampler
.
num_emitted_tokens
,
non_blocking
=
True
)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self
.
_aggregate_num_draft_tokens
=
(
self
.
_rejection
_sampler
.
num_draft_tokens
)
self
.
spec_decode
_sampler
.
num_draft_tokens
)
aggregate_metrics_ready
=
torch
.
cuda
.
Event
()
aggregate_metrics_ready
.
record
(
self
.
_copy_stream
)
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
80ca1e6a
...
...
@@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
,
get_all_seq_ids
)
...
...
@@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_by_batch_size
=
speculative_config
.
speculative_disable_by_batch_size
,
)
draft_token_acceptance_method
=
speculative_config
.
draft_token_acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
speculative_config
.
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
)
return
spec_decode_worker
...
...
@@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* Only lossless rejection sampling is supported. Contributions adding lossy
verification routines are welcome (e.g. Medusa's typical acceptance).
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
...
...
@@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker
:
Worker
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_by_batch_size
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
)
->
"SpecDecodeWorker"
:
ngram_prompt_lookup_max
=
(
...
...
@@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
if
draft_token_acceptance_method
==
"rejection_sampler"
:
spec_decode_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
)
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
spec_decode_sampler
=
TypicalAcceptanceSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
)
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
))
spec_decode_sampler
=
spec_decode_sampler
)
def
__init__
(
self
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
Rejection
Sampler
,
spec_decode_sampler
:
SpecDecodeBase
Sampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
):
...
...
@@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
spec_decode_sampler: A Torch module used to perform acceptance
sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
...
...
@@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
rejection_sampler
=
rejection_sampler
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_metrics
=
AsyncMetricsCollector
(
rejection
_sampler
self
.
spec_decode
_sampler
)
if
metrics_collector
is
None
else
metrics_collector
self
.
probs_dtype
=
self
.
rejection_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
rejection_sampler
.
token_id_dtype
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initiazliation.
self
.
scorer
:
SpeculativeScorer
...
...
@@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
...
...
@@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
_configure_model_sampler_for_spec_decode
(
self
):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of
rejection sampling
.
which significantly reduces overhead of
sampling during verification
.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
...
...
@@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
accepted_token_ids
=
self
.
rejection
_sampler
(
accepted_token_ids
=
self
.
spec_decode
_sampler
(
target_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
...
...
@@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment