Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15702038
Unverified
Commit
15702038
authored
Oct 01, 2024
by
Lily Liu
Committed by
GitHub
Oct 01, 2024
Browse files
[Spec Decode] (1/2) Remove batch expansion (#8839)
parent
22f5851b
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
175 additions
and
63 deletions
+175
-63
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+11
-8
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+1
-1
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+12
-11
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+0
-7
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+0
-2
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+7
-0
vllm/spec_decode/mqa_scorer.py
vllm/spec_decode/mqa_scorer.py
+80
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+54
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+10
-27
No files found.
vllm/engine/output_processor/multi_step.py
View file @
15702038
import
functools
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
from
vllm.engine.output_processor.interfaces
import
(
...
@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
process_outputs
(
self
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
],
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
=
False
)
->
None
:
is_async
:
bool
=
False
)
->
Optional
[
int
]
:
"""Append new tokens in the outputs to sequences in the sequence group.
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
This only supports sequence groups of size 1. It supports greater than
...
@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
tokens from the previous step. If this is true, then
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
no tokens need to be appended since it is already done
externally (before the next schedule() call)
externally (before the next schedule() call)
Returns:
The number of tokens appended to the sequence. This is optional
because only speculative decode uses this return value.
"""
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
...
@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# was already appended, so we only need to do the rest of the
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
# postprocessor: Detokenization + stopping logic
self
.
_process_decode_and_stop
(
seq
,
sequence_group
.
sampling_params
)
self
.
_process_decode_and_stop
(
seq
,
sequence_group
.
sampling_params
)
return
None
else
:
else
:
# Standard multi-step case
# Standard multi-step case
...
@@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
]
]
assert
valid_samples
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
return
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
...
@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
valid_samples
:
List
[
SequenceOutput
],
valid_samples
:
List
[
SequenceOutput
],
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
int
:
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_logprobs
=
[
sample
.
logprobs
for
sample
in
valid_samples
]
output_logprobs
=
[
sample
.
logprobs
for
sample
in
valid_samples
]
...
@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
len
(
output_token_ids
))
len
(
output_token_ids
))
if
remaining_tokens
<
0
:
if
remaining_tokens
<
0
:
valid_samples
=
valid_samples
[:
remaining_tokens
]
output_token_ids
=
output_token_ids
[:
remaining_tokens
]
output_token_ids
=
output_token_ids
[:
remaining_tokens
]
# Truncate any tokens after EOS. This is required as spec decode
# Truncate any tokens after EOS. This is required as spec decode
...
@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for
i
in
range
(
len
(
output_token_ids
)):
for
i
in
range
(
len
(
output_token_ids
)):
if
output_token_ids
[
i
]
==
eos_token_id
:
if
output_token_ids
[
i
]
==
eos_token_id
:
output_token_ids
=
output_token_ids
[:
i
+
1
]
output_token_ids
=
output_token_ids
[:
i
+
1
]
valid_samples
=
valid_samples
[:
i
+
1
]
break
break
# Incrementally append tokens to the sequence, as if we had only one new
# Incrementally append tokens to the sequence, as if we had only one new
...
@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
token_id
=
output_token_id
,
token_id
=
output_token_id
,
logprobs
=
output_logprob
,
logprobs
=
output_logprob
,
)
)
seq
.
data
.
update_num_computed_tokens
(
1
)
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
if
seq
.
is_finished
():
if
seq
.
is_finished
():
break
break
return
len
(
output_token_ids
)
vllm/model_executor/layers/sampler.py
View file @
15702038
...
@@ -912,7 +912,7 @@ def get_logprobs(
...
@@ -912,7 +912,7 @@ def get_logprobs(
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
SampleResultType
,
sample_results
:
SampleResultType
,
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
"""Return sample lo
b
probs and prompt logprobs.
"""Return sample lo
g
probs and prompt logprobs.
The logic consists of 3 parts.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and
- Select indices to compute logprob from, ranks of token ids, and
...
...
vllm/model_executor/sampling_metadata.py
View file @
15702038
...
@@ -146,7 +146,7 @@ class SamplingMetadata:
...
@@ -146,7 +146,7 @@ class SamplingMetadata:
def
prepare
(
def
prepare
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]
]
,
query_lens
:
List
[
int
],
device
:
str
,
device
:
str
,
pin_memory
:
bool
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
...
@@ -194,7 +194,7 @@ class SamplingMetadata:
...
@@ -194,7 +194,7 @@ class SamplingMetadata:
def
_prepare_seq_groups
(
def
_prepare_seq_groups
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]
]
,
query_lens
:
List
[
int
],
device
:
str
,
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
...
@@ -284,7 +284,8 @@ def _prepare_seq_groups(
...
@@ -284,7 +284,8 @@ def _prepare_seq_groups(
else
:
else
:
# Decode
# Decode
prompt_logprob_len
=
0
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
query_len
=
query_lens
[
i
]
if
query_lens
is
not
None
else
1
sample_len
=
len
(
seq_ids
)
*
query_len
if
do_sample
else
0
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
...
@@ -440,14 +441,14 @@ class SamplingTensors:
...
@@ -440,14 +441,14 @@ class SamplingTensors:
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
sample_lens
=
len
(
seq_group
.
sample_indices
)
sample_lens
=
len
(
seq_group
.
sample_indices
)
assert
sample_lens
=
=
len
(
seq_ids
)
assert
sample_lens
>
=
len
(
seq_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
temperatures
+=
[
temperature
]
*
sample_lens
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
sample_lens
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
sample_lens
min_ps
+=
[
min_p
]
*
len
(
seq_ids
)
min_ps
+=
[
min_p
]
*
sample_lens
presence_penalties
+=
[
p
]
*
len
(
seq_ids
)
presence_penalties
+=
[
p
]
*
sample_lens
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
sample_lens
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
sample_lens
if
do_penalties
:
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
...
...
vllm/spec_decode/batch_expansion.py
View file @
15702038
...
@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
...
@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.worker.worker_base
import
WorkerBase
SeqId
=
int
SeqId
=
int
TargetSeqId
=
int
TargetSeqId
=
int
...
@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
of topk/tree.
"""
"""
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
def
score_proposals
(
self
,
self
,
...
...
vllm/spec_decode/draft_model_runner.py
View file @
15702038
...
@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
assert
seq_group
.
is_prompt
is
False
# No prompt
assert
seq_group
.
is_prompt
is
False
# No prompt
assert
seq_group
.
prompt_logprob_indices
==
[]
# No prompt
assert
seq_group
.
prompt_logprob_indices
==
[]
# No prompt
assert
seq_group
.
sample_indices
==
[
i
]
# Simple
assert
seq_group
.
sample_indices
==
[
i
]
# Simple
assert
seq_group
.
seq_len
is
None
# Decode
assert
seq_group
.
query_len
is
None
# Decode
def
_gpu_advance_step
(
def
_gpu_advance_step
(
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
...
...
vllm/spec_decode/interfaces.py
View file @
15702038
...
@@ -5,6 +5,7 @@ from typing import Optional, Set
...
@@ -5,6 +5,7 @@ from typing import Optional, Set
import
torch
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.worker_base
import
WorkerBase
@
dataclass
@
dataclass
...
@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
...
@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
class
SpeculativeScorer
(
ABC
):
class
SpeculativeScorer
(
ABC
):
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
@
abstractmethod
@
abstractmethod
def
score_proposals
(
def
score_proposals
(
self
,
self
,
...
...
vllm/spec_decode/mqa_scorer.py
0 → 100644
View file @
15702038
from
vllm.sequence
import
(
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SeqId
=
int
TargetSeqId
=
int
class
MQAScorer
(
SpeculativeScorer
):
def
score_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
target_seq_group_metadata_list
=
[]
target_seq_id_start
=
max
(
get_all_seq_ids
(
execute_model_req
.
seq_group_metadata_list
))
+
1
all_proposal_tokens
=
proposals
.
proposal_token_ids
.
tolist
()
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
seq_data_dict
=
seq_group_metadata
.
seq_data
assert
len
(
seq_data_dict
)
==
1
seq_id
=
next
(
iter
(
seq_data_dict
.
keys
()))
seq_data
:
SequenceData
=
seq_data_dict
[
seq_id
]
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
output_token_ids
=
seq_data
.
get_output_token_ids
()
proposal_token_ids
=
all_proposal_tokens
[
i
]
new_output_token_ids
=
[
*
output_token_ids
,
*
proposal_token_ids
]
target_seq_id
=
target_seq_id_start
+
i
new_seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
=
prompt_token_ids
,
output_token_ids
=
new_output_token_ids
,
)
new_seq_data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
output_token_ids
)
-
1
)
# Ensure that the new sequence has at least one token
# because we only use mqa scorer in the decoding stage.
assert
len
(
output_token_ids
)
>=
1
new_seq_data_dict
=
{
target_seq_id
:
new_seq_data
}
new_seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
lora_request
=
None
,
token_chunk_size
=
1
,
)
target_seq_group_metadata_list
.
append
(
new_seq_group_metadata
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
))
target_sampler_output
=
target_sampler_output
[
0
]
bs
,
k
=
proposals
.
proposal_token_ids
.
shape
all_tokens
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
bs
,
k
+
1
)
all_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
hidden_states
=
None
if
target_sampler_output
.
hidden_states
is
not
None
:
hidden_states
=
target_sampler_output
.
hidden_states
.
reshape
(
bs
,
(
k
+
1
),
-
1
)
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
logprobs
=
all_logprobs
,
hidden_states
=
hidden_states
)
vllm/spec_decode/spec_decode_worker.py
View file @
15702038
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch
...
@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
...
@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
...
@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_mqa_scorer
=
speculative_config
.
speculative_disable_mqa_scorer
,
disable_by_batch_size
=
speculative_config
.
disable_by_batch_size
=
speculative_config
.
speculative_disable_by_batch_size
,
speculative_disable_by_batch_size
,
draft_token_acceptance_method
=
speculative_config
.
draft_token_acceptance_method
=
speculative_config
.
...
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
cls
,
cls
,
scorer_worker
:
Worker
,
scorer_worker
:
Worker
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_mqa_scorer
:
bool
,
disable_by_batch_size
:
Optional
[
int
],
disable_by_batch_size
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_threshold
:
float
,
...
@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
)
)
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
logger
.
info
(
type
(
spec_decode_sampler
))
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
if
not
disable_mqa_scorer
:
if
scorer_worker
.
model_runner
.
attn_backend
.
get_name
(
)
!=
"flash-attn"
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend."
)
if
ngram_prompt_lookup_max
>
0
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer."
)
if
"model_config"
in
draft_worker_kwargs
and
\
draft_worker_kwargs
[
"model_config"
].
max_model_len
<
\
scorer_worker
.
model_config
.
max_model_len
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len."
)
if
not
scorer_worker
.
model_runner
.
model_config
.
enforce_eager
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode."
)
return
SpecDecodeWorker
(
return
SpecDecodeWorker
(
proposer_worker
,
proposer_worker
,
scorer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
...
@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker
:
ProposerWorkerBase
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
disable_mqa_scorer
:
bool
=
False
,
disable_logprobs
:
bool
=
False
,
disable_logprobs
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
If set to False, log probabilities will be output by both.
...
@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initialization.
# Lazy initialization.
self
.
scorer
:
SpeculativeScorer
self
.
scorer
:
SpeculativeScorer
self
.
disable_mqa_scorer
=
disable_mqa_scorer
# Hidden states from target model to pass to proposer
# Hidden states from target model to pass to proposer
# in the subsequent step.
# in the subsequent step.
...
@@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
scorer
=
BatchExpansionTop1Scorer
(
scorer_cls
:
Type
[
SpeculativeScorer
]
scorer_worker
=
self
.
scorer_worker
,
if
self
.
disable_mqa_scorer
:
device
=
self
.
device
,
scorer_cls
=
BatchExpansionTop1Scorer
vocab_size
=
self
.
_vocab_size
)
logger
.
info
(
"[Speculative Decoding] Use batch "
"expansion for scoring proposals."
)
else
:
scorer_cls
=
MQAScorer
logger
.
info
(
"[Speculative Decoding] Use MQA scorer for scoring proposals."
)
self
.
scorer
=
scorer_cls
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
self
.
_configure_model_sampler_for_spec_decode
()
self
.
_configure_model_sampler_for_spec_decode
()
...
...
vllm/worker/model_runner.py
View file @
15702038
...
@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute context length (the number of tokens that are
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
# already computed) and sequence length (total number of tokens).
seq_len
=
seq_data
.
get_len
()
seq_len
=
seq_data
.
get_len
()
if
inter_data
.
is_prompt
:
if
inter_data
.
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# get_num_computed_tokens is incorrect for spec decoding.
elif
self
.
runner
.
scheduler_config
.
is_multi_step
or
\
# So, we should have a special logic here.
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
# TODO(sang): Fix it.
context_len
=
seq_len
-
1
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
else
:
context_len
=
seq_data
.
get_num_computed_tokens
()
# Compute tokens.
# Compute tokens.
if
inter_data
.
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
tokens
=
seq_data
.
get_token_ids
()
if
context_len
!=
0
or
seq_len
<
len
(
tokens
):
tokens
=
tokens
[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
seq_data
.
get_last_token_id
()
inter_data
.
seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
orig_seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
orig_seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
input_tokens
[
seq_idx
].
extend
(
tokens
)
if
isinstance
(
tokens
,
list
):
inter_data
.
input_positions
[
seq_idx
].
extend
(
range
(
context_len
,
seq_len
))
inter_data
.
input_tokens
[
seq_idx
].
extend
(
tokens
)
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
else
:
inter_data
.
input_tokens
[
seq_idx
].
append
(
tokens
)
if
(
seq_len
-
context_len
)
==
1
:
inter_data
.
input_positions
[
seq_idx
].
append
(
seq_len
-
1
)
else
:
inter_data
.
input_positions
[
seq_idx
].
extend
(
range
(
context_len
,
seq_len
))
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
if
seq_data
.
mrope_position_delta
is
not
None
:
if
seq_data
.
mrope_position_delta
is
not
None
:
if
inter_data
.
mrope_input_positions
is
None
:
if
inter_data
.
mrope_input_positions
is
None
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment