Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f942efb5
Unverified
Commit
f942efb5
authored
May 08, 2024
by
Cody Yu
Committed by
GitHub
May 08, 2024
Browse files
[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by:
Cade Daniel
<
edacih@gmail.com
>
parent
89579a20
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
227 additions
and
39 deletions
+227
-39
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+9
-4
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+34
-0
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+1
-1
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+77
-0
vllm/config.py
vllm/config.py
+24
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+10
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+2
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+9
-2
vllm/sequence.py
vllm/sequence.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+38
-21
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+17
-6
No files found.
tests/samplers/test_rejection_sampler.py
View file @
f942efb5
...
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
...
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"which_tokens_accepted"
,
"which_tokens_accepted"
,
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
disable_bonus_tokens
:
bool
,
seed
:
int
,
device
:
str
):
device
:
str
):
"""Verify the output has correct format given predetermined accepted matrix.
"""Verify the output has correct format given predetermined accepted matrix.
"""
"""
...
@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
rejection_sampler
=
RejectionSampler
()
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
)
rejection_sampler
.
init_gpu_tensors
(
rank
=
0
)
rejection_sampler
.
init_gpu_tensors
(
rank
=
0
)
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
accepted
,
accepted
,
...
@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
...
@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids
,
bonus_token_ids
,
)
)
# Bonus tokens are currently disabled. Verify they're set to -1.
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
*
0
-
1
if
disable_bonus_tokens
:
expected_bonus_token_ids
=
expected_bonus_token_ids
*
0
-
1
if
which_tokens_accepted
==
"all_tokens_accepted"
:
if
which_tokens_accepted
==
"all_tokens_accepted"
:
# Expect all tokens to be equal to draft tokens.
# Expect all tokens to be equal to draft tokens.
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
f942efb5
...
@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
...
@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"speculative_disable_by_batch_size"
:
2
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_disable_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
f942efb5
...
@@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
...
@@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
256
,
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
test_llm_generator
,
batch_size
:
int
,
...
...
tests/spec_decode/test_dynamic_spec_decode.py
0 → 100644
View file @
f942efb5
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
create_batch
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'queue_size'
,
[
2
,
4
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
5
,
7
,
10
])
@
torch
.
inference_mode
()
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size
=
3
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
rejection_sampler
=
rejection_sampler
,
metrics_collector
=
metrics_collector
,
disable_by_batch_size
=
disable_by_batch_size
)
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
running_queue_size
=
queue_size
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens
=
None
if
queue_size
<
disable_by_batch_size
else
0
assert
seq_group_metadata_list
[
0
].
num_speculative_tokens
==
expected_num_spec_tokens
draft_worker
.
sampler_output
.
side_effect
=
ValueError
(
exception_secret
)
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
'cpu'
,
# not used
vocab_size
=
100
,
# not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len
=
1024
,
)
if
queue_size
<
disable_by_batch_size
:
# Should raise exception when executing the mocked draft model.
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
else
:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
]
*
batch_size
vllm/config.py
View file @
f942efb5
...
@@ -692,6 +692,7 @@ class SpeculativeConfig:
...
@@ -692,6 +692,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
...
@@ -720,6 +721,9 @@ class SpeculativeConfig:
...
@@ -720,6 +721,9 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
...
@@ -730,7 +734,7 @@ class SpeculativeConfig:
...
@@ -730,7 +734,7 @@ class SpeculativeConfig:
the necessary conditions are met, else None.
the necessary conditions are met, else None.
"""
"""
if
(
speculative_model
is
None
and
num_speculative_tokens
is
None
)
:
if
speculative_model
is
None
and
num_speculative_tokens
is
None
:
return
None
return
None
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
...
@@ -739,6 +743,12 @@ class SpeculativeConfig:
...
@@ -739,6 +743,12 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
"num_speculative_tokens to be provided, but found "
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
if
(
speculative_disable_by_batch_size
is
not
None
and
speculative_disable_by_batch_size
<
2
):
raise
ValueError
(
"Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
assert
(
speculative_model
is
not
None
assert
(
speculative_model
is
not
None
and
num_speculative_tokens
is
not
None
)
and
num_speculative_tokens
is
not
None
)
...
@@ -807,6 +817,7 @@ class SpeculativeConfig:
...
@@ -807,6 +817,7 @@ class SpeculativeConfig:
draft_model_config
,
draft_model_config
,
draft_parallel_config
,
draft_parallel_config
,
num_speculative_tokens
,
num_speculative_tokens
,
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
,
)
)
...
@@ -876,8 +887,9 @@ class SpeculativeConfig:
...
@@ -876,8 +887,9 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_min
:
int
,
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -886,12 +898,19 @@ class SpeculativeConfig:
...
@@ -886,12 +898,19 @@ class SpeculativeConfig:
draft_parallel_config: ParallelConfig for the draft model.
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
speculative_disable_by_batch_size
=
\
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
_verify_args
()
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
f942efb5
...
@@ -83,6 +83,7 @@ class EngineArgs:
...
@@ -83,6 +83,7 @@ class EngineArgs:
speculative_model
:
Optional
[
str
]
=
None
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
...
@@ -467,6 +468,13 @@ class EngineArgs:
...
@@ -467,6 +468,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'draft model. Sequences over this length will skip '
'speculation.'
)
'speculation.'
)
parser
.
add_argument
(
'--speculative-disable-by-batch-size'
,
type
=
int
,
default
=
EngineArgs
.
speculative_disable_by_batch_size
,
help
=
'Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--ngram-prompt-lookup-max'
,
'--ngram-prompt-lookup-max'
,
type
=
int
,
type
=
int
,
...
@@ -547,6 +555,8 @@ class EngineArgs:
...
@@ -547,6 +555,8 @@ class EngineArgs:
target_dtype
=
self
.
dtype
,
target_dtype
=
self
.
dtype
,
speculative_model
=
self
.
speculative_model
,
speculative_model
=
self
.
speculative_model
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
speculative_disable_by_batch_size
=
self
.
speculative_disable_by_batch_size
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
...
...
vllm/executor/gpu_executor.py
View file @
f942efb5
...
@@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
...
@@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_by_batch_size
=
self
.
speculative_config
.
speculative_disable_by_batch_size
,
)
)
assert
self
.
parallel_config
.
world_size
==
1
,
(
assert
self
.
parallel_config
.
world_size
==
1
,
(
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
f942efb5
...
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
...
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
https://arxiv.org/pdf/2302.01318.pdf.
"""
"""
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
"""Create a rejection sampler.
"""Create a rejection sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
@@ -312,6 +318,7 @@ class RejectionSampler(nn.Module):
...
@@ -312,6 +318,7 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling"
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
# https://github.com/vllm-project/vllm/issues/4212
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
# Fill the recovered token ids.
...
...
vllm/sequence.py
View file @
f942efb5
...
@@ -612,6 +612,12 @@ class SequenceGroupMetadata:
...
@@ -612,6 +612,12 @@ class SequenceGroupMetadata:
self
.
_token_chunk_size
=
token_chunk_size
self
.
_token_chunk_size
=
token_chunk_size
self
.
do_sample
=
do_sample
self
.
do_sample
=
do_sample
# The number of speculative tokens adopted in this request.
# None means specuative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self
.
num_speculative_tokens
=
None
if
self
.
_token_chunk_size
is
None
:
if
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
if
is_prompt
:
self
.
_token_chunk_size
=
list
(
seq_data
.
values
())[
0
].
get_len
()
self
.
_token_chunk_size
=
list
(
seq_data
.
values
())[
0
].
get_len
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
f942efb5
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
create_worker
(
def
create_worker
(
cls
,
cls
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_by_batch_size
:
Optional
[
int
],
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
ngram_prompt_lookup_max
=
(
ngram_prompt_lookup_max
=
(
...
@@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min
=
(
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
disable_bonus_tokens
=
True
if
ngram_prompt_lookup_max
>
0
:
if
ngram_prompt_lookup_max
>
0
:
disable_bonus_tokens
=
False
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
ngram_prompt_lookup_max
)
...
@@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return
SpecDecodeWorker
(
return
SpecDecodeWorker
(
proposer_worker
,
proposer_worker
,
scorer_worker
,
scorer_worker
,
# TODO(cade) disable strict mode for speedup.
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
),
rejection_sampler
=
RejectionSampler
(
)
disable_bonus_tokens
=
disable_bonus_tokens
,
)
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Worker.
Worker.
rejection_sampler: A Torch module used to perform modified rejection
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
sampling for speculative decoding.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
for testing purposes.
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
rejection_sampler
=
rejection_sampler
self
.
rejection_sampler
=
rejection_sampler
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
_metrics
=
AsyncMetricsCollector
(
...
@@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
"requires non-None seq_group_metadata_list"
)
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all
=
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
if
disable_all
:
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata
.
num_speculative_tokens
=
0
# If no spec tokens, call the proposer and scorer workers normally.
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# This happens for prefill, or when the spec decode is disabled
# for this batch.
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
execute_model_req
.
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
execute_model_req
)
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
self
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to
"""Run a prefill step, without any speculation. The input is sent to the
the proposer and scorer model so that the KV cache is consistent
proposer and scorer model so that the KV cache is consistent between the
between the two. When skip_proposer is True, the proposer model is
two.
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
"""
#logger.info("run proposer worker no spec")
if
not
skip_proposer
:
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
#logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
sampler_output
=
sampler_output
[
0
]
...
@@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
sequence.
"""
"""
#logger.info("get spec proposals")
# Generate proposals using draft worker.
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
#logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
execute_model_req
,
proposals
,
proposals
,
)
)
#logger.info("verify proposals")
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
proposals
,
execute_model_req
.
num_lookahead_slots
)
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
accepted_token_ids
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
f942efb5
...
@@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens
,
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
nonzero_proposal_len_indices
,
)
=
self
.
_split_by_
max_mode
l_len
(
seq_group_metadata_list
,
proposal_len
)
)
=
self
.
_split_by_
proposa
l_len
(
seq_group_metadata_list
,
proposal_len
)
if
nonzero_proposal_len_seqs
:
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
# Speculate tokens using the draft worker for the speculative
...
@@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
return
proposals
return
proposals
def
_split_by_
max_mode
l_len
(
def
_split_by_
proposa
l_len
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_len
:
int
,
proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Determine which sequences would exceed the max model length."""
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens
:
List
[
int
]
=
[]
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if
seq_group_metadata
.
num_speculative_tokens
==
0
:
proposal_lens
.
append
(
0
)
continue
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
seq_len
=
seq_data
.
get_len
()
...
@@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
# are supported.
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal
# quota for nonzero_proposal
new_k
=
0
if
(
self
.
max_proposal_len
is
None
if
(
self
.
max_proposal_len
is
None
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
proposal_lens
.
append
(
proposal_len
)
new_k
=
proposal_len
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
nonzero_proposal_len_indices
.
append
(
i
)
else
:
proposal_lens
.
append
(
new_k
)
proposal_lens
.
append
(
0
)
seq_group_metadata
.
num_speculative_tokens
=
new_k
return
(
return
(
proposal_lens
,
proposal_lens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment