Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f942efb5
Unverified
Commit
f942efb5
authored
May 08, 2024
by
Cody Yu
Committed by
GitHub
May 08, 2024
Browse files
[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by:
Cade Daniel
<
edacih@gmail.com
>
parent
89579a20
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
227 additions
and
39 deletions
+227
-39
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+9
-4
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+34
-0
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+1
-1
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+77
-0
vllm/config.py
vllm/config.py
+24
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+10
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+2
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+9
-2
vllm/sequence.py
vllm/sequence.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+38
-21
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+17
-6
No files found.
tests/samplers/test_rejection_sampler.py
View file @
f942efb5
...
...
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
@
pytest
.
mark
.
parametrize
(
"which_tokens_accepted"
,
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
disable_bonus_tokens
:
bool
,
seed
:
int
,
device
:
str
):
"""Verify the output has correct format given predetermined accepted matrix.
"""
...
...
@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
rejection_sampler
=
RejectionSampler
()
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
)
rejection_sampler
.
init_gpu_tensors
(
rank
=
0
)
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
accepted
,
...
...
@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids
,
)
# Bonus tokens are currently disabled. Verify they're set to -1.
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
*
0
-
1
if
disable_bonus_tokens
:
expected_bonus_token_ids
=
expected_bonus_token_ids
*
0
-
1
if
which_tokens_accepted
==
"all_tokens_accepted"
:
# Expect all tokens to be equal to draft tokens.
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
f942efb5
...
...
@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Eager execution avoids CUDA graph recording, keeping the test fast.
        "enforce_eager": True,

        # Speculative decoding requires the v2 block manager.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "speculative_disable_by_batch_size": 2,
    },
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
                             batch_size: int, output_len: int):
    """Greedy outputs must match the baseline when speculation is disabled.

    The test engine is configured with
    ``speculative_disable_by_batch_size=2`` while the batch holds 8
    requests, so the run is expected to exercise the disabled-speculation
    path. The generated tokens are compared against the baseline engine
    (no speculative model) under greedy sampling.

    Args:
        baseline_llm_generator: Fixture producing the reference engine.
        test_llm_generator: Fixture producing the engine under test.
        batch_size: Number of prompts submitted together.
        output_len: Number of tokens to generate per prompt; the output
            length is forced so both engines emit exactly this many.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
f942efb5
...
...
@@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
...
...
tests/spec_decode/test_dynamic_spec_decode.py
0 → 100644
View file @
f942efb5
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
create_batch
,
mock_worker
@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
    """Verify that speculative tokens are disabled once the running queue
    size reaches the configured threshold.

    The test runs in two phases:

    1. ``SpecDecodeWorker.execute_model`` is called with a draft worker whose
       ``get_spec_proposals`` raises an artificial error. The call is
       expected to fail with that error, and afterwards the first sequence
       group's ``num_speculative_tokens`` must be ``0`` when the queue size
       is at or above the threshold, or still ``None`` otherwise.
    2. A ``Top1Proposer`` is run on the same sequence groups. When
       speculation is still enabled it must reach the mocked draft model
       (and therefore raise); when it has been disabled it must return
       zero-length proposals for every sequence without raising.

    Args:
        queue_size: Value reported as ``running_queue_size`` in the request.
        batch_size: Number of sequence groups in the batch.
        k: Number of lookahead slots (speculative tokens) requested.
    """
    threshold = 3
    speculation_enabled = queue_size < threshold
    stop_message = 'artificial stop'

    proposer_mock = mock_worker(cls=MultiStepWorker)
    scorer_mock = mock_worker()
    spec_worker = SpecDecodeWorker(
        proposer_worker=proposer_mock,
        scorer_worker=scorer_mock,
        rejection_sampler=MagicMock(spec=RejectionSampler),
        metrics_collector=MagicMock(spec=AsyncMetricsCollector),
        disable_by_batch_size=threshold,
    )

    # Abort the speculative step as soon as the draft worker is asked for
    # proposals; only the bookkeeping done before that point is of interest.
    proposer_mock.get_spec_proposals.side_effect = ValueError(stop_message)

    seq_groups, _, _ = create_batch(batch_size, k)
    first_request = ExecuteModelRequest(
        seq_group_metadata_list=seq_groups,
        num_lookahead_slots=k,
        running_queue_size=queue_size,
    )

    with pytest.raises(ValueError, match=stop_message):
        spec_worker.execute_model(execute_model_req=first_request)

    # A queue at or above the threshold must zero out the speculative token
    # count; otherwise the attribute keeps its initial value of None.
    assert seq_groups[0].num_speculative_tokens == (
        None if speculation_enabled else 0)

    proposer_mock.sampler_output.side_effect = ValueError(stop_message)
    top1 = Top1Proposer(
        worker=proposer_mock,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )
    second_request = ExecuteModelRequest(
        seq_group_metadata_list=seq_groups,
        num_lookahead_slots=k,
    )

    if not speculation_enabled:
        # Speculation is off for every request, so the draft model must not
        # be executed and every proposal length must be 0.
        result = top1.get_proposals(execute_model_req=second_request)
        assert result.proposal_lens.tolist() == [0] * batch_size
    else:
        # Speculation is still on, so the mocked draft model is executed and
        # its artificial error propagates.
        with pytest.raises(ValueError, match=stop_message):
            top1.get_proposals(execute_model_req=second_request)
vllm/config.py
View file @
f942efb5
...
...
@@ -692,6 +692,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
...
...
@@ -720,6 +721,9 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
...
...
@@ -730,7 +734,7 @@ class SpeculativeConfig:
the necessary conditions are met, else None.
"""
if
(
speculative_model
is
None
and
num_speculative_tokens
is
None
)
:
if
speculative_model
is
None
and
num_speculative_tokens
is
None
:
return
None
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
...
...
@@ -739,6 +743,12 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
if
(
speculative_disable_by_batch_size
is
not
None
and
speculative_disable_by_batch_size
<
2
):
raise
ValueError
(
"Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
assert
(
speculative_model
is
not
None
and
num_speculative_tokens
is
not
None
)
...
...
@@ -807,6 +817,7 @@ class SpeculativeConfig:
draft_model_config
,
draft_parallel_config
,
num_speculative_tokens
,
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
)
...
...
@@ -876,8 +887,9 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
ngram_prompt_lookup_min
:
int
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
):
"""Create a SpeculativeConfig object.
...
...
@@ -886,12 +898,19 @@ class SpeculativeConfig:
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
self
.
speculative_disable_by_batch_size
=
\
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
f942efb5
...
...
@@ -83,6 +83,7 @@ class EngineArgs:
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
...
...
@@ -467,6 +468,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'speculation.'
)
parser
.
add_argument
(
'--speculative-disable-by-batch-size'
,
type
=
int
,
default
=
EngineArgs
.
speculative_disable_by_batch_size
,
help
=
'Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.'
)
parser
.
add_argument
(
'--ngram-prompt-lookup-max'
,
type
=
int
,
...
...
@@ -547,6 +555,8 @@ class EngineArgs:
target_dtype
=
self
.
dtype
,
speculative_model
=
self
.
speculative_model
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
speculative_disable_by_batch_size
=
self
.
speculative_disable_by_batch_size
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
...
...
vllm/executor/gpu_executor.py
View file @
f942efb5
...
...
@@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_by_batch_size
=
self
.
speculative_config
.
speculative_disable_by_batch_size
,
)
assert
self
.
parallel_config
.
world_size
==
1
,
(
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
f942efb5
...
...
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
"""
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
...
@@ -312,7 +318,8 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens
[:,
-
1
]
=
-
1
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
...
...
vllm/sequence.py
View file @
f942efb5
...
...
@@ -612,6 +612,12 @@ class SequenceGroupMetadata:
self
.
_token_chunk_size
=
token_chunk_size
self
.
do_sample
=
do_sample
# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this state outside of the sequence group.
self
.
num_speculative_tokens
=
None
if
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
self
.
_token_chunk_size
=
list
(
seq_data
.
values
())[
0
].
get_len
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
f942efb5
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
create_worker
(
cls
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_by_batch_size
:
Optional
[
int
],
)
->
"SpecDecodeWorker"
:
ngram_prompt_lookup_max
=
(
...
...
@@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
disable_bonus_tokens
=
True
if
ngram_prompt_lookup_max
>
0
:
disable_bonus_tokens
=
False
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
...
...
@@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
# TODO(cade) disable strict mode for speedup.
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
),
)
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
)
)
def
__init__
(
self
,
...
...
@@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
):
"""
Create a SpecDecodeWorker.
...
...
@@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Worker.
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
"""
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
rejection_sampler
=
rejection_sampler
self
.
_metrics
=
AsyncMetricsCollector
(
...
...
@@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all
=
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
if
disable_all
:
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata
.
num_speculative_tokens
=
0
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# This happens for prefill, or when the spec decode is disabled
# for this batch.
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
execute_model_req
)
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to
the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
#logger.info("run proposer worker no spec")
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
if
not
skip_proposer
:
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
#logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
...
...
@@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
#logger.info("get spec proposals")
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
#logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
proposals
,
)
#logger.info("verify proposals")
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
f942efb5
...
...
@@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
=
self
.
_split_by_
max_mode
l_len
(
seq_group_metadata_list
,
proposal_len
)
)
=
self
.
_split_by_
proposa
l_len
(
seq_group_metadata_list
,
proposal_len
)
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
...
...
@@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
return
proposals
def
_split_by_
max_mode
l_len
(
def
_split_by_
proposa
l_len
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Determine which sequences would exceed the max model length."""
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if
seq_group_metadata
.
num_speculative_tokens
==
0
:
proposal_lens
.
append
(
0
)
continue
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
...
...
@@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
# are supported.
# If max_proposal_len is defined, then we shall not exceed this
# quota for nonzero_proposal
new_k
=
0
if
(
self
.
max_proposal_len
is
None
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
proposal_lens
.
append
(
proposal_len
)
new_k
=
proposal_len
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
else
:
proposal_lens
.
append
(
0
)
proposal_lens
.
append
(
new_k
)
seq_group_metadata
.
num_speculative_tokens
=
new_k
return
(
proposal_lens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment