Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd572c0a
Unverified
Commit
dd572c0a
authored
Jul 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 18, 2025
Browse files
[V0 Deprecation] Remove V0 Spec Decode workers (#21152)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent
9ffe905a
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing 20 changed files with 28 additions and 2390 deletions
+28
-2390
vllm/config.py
vllm/config.py
+7
-54
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-22
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+0
-66
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+1
-11
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+0
-5
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+0
-406
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+3
-9
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+0
-259
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+0
-166
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+0
-261
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+3
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-8
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-8
vllm/sequence.py
vllm/sequence.py
+1
-13
vllm/spec_decode/__init__.py
vllm/spec_decode/__init__.py
+0
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+0
-506
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+0
-349
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+0
-99
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+0
-138
No files found.
vllm/config.py
View file @
dd572c0a
...
@@ -2536,8 +2536,6 @@ class DeviceConfig:
...
@@ -2536,8 +2536,6 @@ class DeviceConfig:
SpeculativeMethod
=
Literal
[
"ngram"
,
"eagle"
,
"eagle3"
,
"medusa"
,
SpeculativeMethod
=
Literal
[
"ngram"
,
"eagle"
,
"eagle3"
,
"medusa"
,
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
]
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
]
SpeculativeAcceptanceMethod
=
Literal
[
"rejection_sampler"
,
"typical_acceptance_sampler"
]
@
config
@
config
...
@@ -2560,13 +2558,6 @@ class SpeculativeConfig:
...
@@ -2560,13 +2558,6 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
`prompt_lookup_min` should be considered."""
acceptance_method
:
SpeculativeAcceptanceMethod
=
"rejection_sampler"
"""The method to use for accepting draft tokens:
\n
- "rejection_sampler" maps to `RejectionSampler`.
\n
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
If using `typical_acceptance_sampler`, the related configuration
`posterior_threshold` and `posterior_alpha` should be considered."""
draft_tensor_parallel_size
:
Optional
[
int
]
=
None
draft_tensor_parallel_size
:
Optional
[
int
]
=
None
"""The degree of the tensor parallelism for the draft model. Can only be 1
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
or the same as the target model's tensor parallel size."""
...
@@ -2593,9 +2584,6 @@ class SpeculativeConfig:
...
@@ -2593,9 +2584,6 @@ class SpeculativeConfig:
will use the default version."""
will use the default version."""
# Advanced control
# Advanced control
disable_mqa_scorer
:
bool
=
False
"""Disable the MQA scorer and fall back to batch expansion for scoring
proposals."""
disable_by_batch_size
:
Optional
[
int
]
=
None
disable_by_batch_size
:
Optional
[
int
]
=
None
"""Disable speculative decoding for new incoming requests when the number
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
of enqueued requests is larger than this value, if provided."""
...
@@ -2608,16 +2596,6 @@ class SpeculativeConfig:
...
@@ -2608,16 +2596,6 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
provided. Defaults to 1."""
# Typical acceptance sampler configuration
posterior_threshold
:
Optional
[
float
]
=
None
"""A threshold value that sets a lower bound on the posterior probability
of a token in the target model for it to be accepted. This threshold is
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
"""
posterior_alpha
:
Optional
[
float
]
=
None
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""
speculative_token_tree
:
Optional
[
str
]
=
None
speculative_token_tree
:
Optional
[
str
]
=
None
"""Specifies the tree structure for speculative token generation.
"""Specifies the tree structure for speculative token generation.
"""
"""
...
@@ -2795,8 +2773,8 @@ class SpeculativeConfig:
...
@@ -2795,8 +2773,8 @@ class SpeculativeConfig:
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
):
"mlp_speculator"
):
self
.
method
=
"mlp_speculator"
self
.
method
=
"mlp_speculator"
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
"deepseek_mtp"
):
in
(
"deepseek_mtp"
,
"mimo_mtp"
)
):
self
.
method
=
"deepseek_mtp"
self
.
method
=
"deepseek_mtp"
if
self
.
num_speculative_tokens
>
1
:
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
logger
.
warning
(
...
@@ -2806,6 +2784,11 @@ class SpeculativeConfig:
...
@@ -2806,6 +2784,11 @@ class SpeculativeConfig:
)
)
else
:
else
:
self
.
method
=
"draft_model"
self
.
method
=
"draft_model"
raise
NotImplementedError
(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp."
)
# Replace hf_config for EAGLE draft_model
# Replace hf_config for EAGLE draft_model
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
...
@@ -2864,12 +2847,6 @@ class SpeculativeConfig:
...
@@ -2864,12 +2847,6 @@ class SpeculativeConfig:
self
.
target_parallel_config
,
self
.
target_parallel_config
,
self
.
draft_tensor_parallel_size
))
self
.
draft_tensor_parallel_size
))
if
self
.
acceptance_method
==
"typical_acceptance_sampler"
:
if
self
.
posterior_threshold
is
None
:
self
.
posterior_threshold
=
0.09
if
self
.
posterior_alpha
is
None
:
self
.
posterior_alpha
=
0.3
@
staticmethod
@
staticmethod
def
_maybe_override_draft_max_model_len
(
def
_maybe_override_draft_max_model_len
(
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
...
@@ -2975,30 +2952,6 @@ class SpeculativeConfig:
...
@@ -2975,30 +2952,6 @@ class SpeculativeConfig:
if
self
.
draft_model_config
:
if
self
.
draft_model_config
:
self
.
draft_model_config
.
verify_with_parallel_config
(
self
.
draft_model_config
.
verify_with_parallel_config
(
self
.
draft_parallel_config
)
self
.
draft_parallel_config
)
# Validate and set draft token acceptance related settings.
if
self
.
acceptance_method
is
None
:
raise
ValueError
(
"acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler."
)
if
(
self
.
acceptance_method
!=
'rejection_sampler'
and
self
.
acceptance_method
!=
'typical_acceptance_sampler'
):
raise
ValueError
(
"Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f
"is
{
self
.
acceptance_method
}
"
)
if
self
.
acceptance_method
==
"typical_acceptance_sampler"
and
(
(
self
.
posterior_threshold
is
not
None
and
self
.
posterior_threshold
<
0
)
or
(
self
.
posterior_alpha
is
not
None
and
self
.
posterior_alpha
<
0
)):
raise
ValueError
(
"Expected the posterior_threshold and posterior_alpha of "
"typical_acceptance_sampler to be > 0. "
"Instead found posterior_threshold = "
f
"
{
self
.
posterior_threshold
}
and posterior_alpha = "
f
"
{
self
.
posterior_alpha
}
"
)
if
(
self
.
disable_by_batch_size
is
not
None
if
(
self
.
disable_by_batch_size
is
not
None
and
self
.
disable_by_batch_size
<
2
):
and
self
.
disable_by_batch_size
<
2
):
...
...
vllm/engine/arg_utils.py
View file @
dd572c0a
...
@@ -1417,28 +1417,12 @@ class EngineArgs:
...
@@ -1417,28 +1417,12 @@ class EngineArgs:
return
False
return
False
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
is_ngram_enabled
=
False
if
(
self
.
speculative_config
is
not
None
is_eagle_enabled
=
False
and
self
.
speculative_config
.
get
(
"method"
)
==
"draft_model"
):
is_medusa_enabled
=
False
raise
NotImplementedError
(
if
self
.
speculative_config
is
not
None
:
"Speculative decoding with draft model is not supported yet. "
# This is supported but experimental (handled below).
"Please consider using other speculative decoding methods "
speculative_method
=
self
.
speculative_config
.
get
(
"method"
)
"such as ngram, medusa, eagle, or deepseek_mtp."
)
if
speculative_method
:
if
speculative_method
in
(
"ngram"
,
"[ngram]"
):
is_ngram_enabled
=
True
elif
speculative_method
==
"medusa"
:
is_medusa_enabled
=
True
elif
speculative_method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
):
is_eagle_enabled
=
True
else
:
speculative_model
=
self
.
speculative_config
.
get
(
"model"
)
if
speculative_model
in
(
"ngram"
,
"[ngram]"
):
is_ngram_enabled
=
True
if
not
(
is_ngram_enabled
or
is_eagle_enabled
or
is_medusa_enabled
):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback
(
feature_name
=
"Speculative Decoding"
,
recommend_to_remove
=
False
)
return
False
# No XFormers so far.
# No XFormers so far.
V1_BACKENDS
=
[
V1_BACKENDS
=
[
...
...
vllm/engine/llm_engine.py
View file @
dd572c0a
...
@@ -1780,13 +1780,6 @@ class LLMEngine:
...
@@ -1780,13 +1780,6 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups
)
num_generation_tokens_from_prefill_groups
)
num_tokens_iter
=
(
num_generation_tokens_iter
+
num_tokens_iter
=
(
num_generation_tokens_iter
+
num_prompt_tokens_iter
)
num_prompt_tokens_iter
)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if
model_output
and
isinstance
(
model_output
[
0
],
SamplerOutput
)
and
(
model_output
[
0
].
spec_decode_worker_metrics
is
not
None
):
spec_decode_metrics
=
model_output
[
0
].
spec_decode_worker_metrics
else
:
spec_decode_metrics
=
None
return
Stats
(
return
Stats
(
now
=
now
,
now
=
now
,
...
@@ -1808,7 +1801,6 @@ class LLMEngine:
...
@@ -1808,7 +1801,6 @@ class LLMEngine:
num_tokens_iter
=
num_tokens_iter
,
num_tokens_iter
=
num_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
spec_decode_metrics
=
spec_decode_metrics
,
num_preemption_iter
=
num_preemption_iter
,
num_preemption_iter
=
num_preemption_iter
,
# Request stats
# Request stats
...
...
vllm/engine/metrics.py
View file @
dd572c0a
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
time
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
...
@@ -19,9 +18,6 @@ if ray is not None:
...
@@ -19,9 +18,6 @@ if ray is not None:
else
:
else
:
ray_metrics
=
None
ray_metrics
=
None
if
TYPE_CHECKING
:
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
prometheus_client
.
disable_created_metrics
()
prometheus_client
.
disable_created_metrics
()
...
@@ -199,30 +195,6 @@ class Metrics:
...
@@ -199,30 +195,6 @@ class Metrics:
documentation
=
"Count of successfully processed requests."
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculative decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_spec_decode_efficiency
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
# --8<-- [end:metrics-definitions]
# --8<-- [end:metrics-definitions]
...
@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
self
.
local_interval
):
...
@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
stats
.
gpu_prefix_cache_hit_rate
*
100
,
stats
.
gpu_prefix_cache_hit_rate
*
100
,
stats
.
cpu_prefix_cache_hit_rate
*
100
,
stats
.
cpu_prefix_cache_hit_rate
*
100
,
)
)
if
self
.
spec_decode_metrics
is
not
None
:
log_fn
(
self
.
_format_spec_decode_metrics_str
(
self
.
spec_decode_metrics
))
self
.
_reset
(
stats
,
prompt_throughput
,
generation_throughput
)
self
.
_reset
(
stats
,
prompt_throughput
,
generation_throughput
)
...
@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
=
[]
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
self
.
last_prompt_throughput
=
prompt_throughput
self
.
last_prompt_throughput
=
prompt_throughput
self
.
last_generation_throughput
=
generation_throughput
self
.
last_generation_throughput
=
generation_throughput
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
return
(
"Speculative metrics: "
f
"Draft acceptance rate:
{
metrics
.
draft_acceptance_rate
:.
3
f
}
, "
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens:
{
metrics
.
emitted_tokens
}
."
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
self
.
local_interval
):
if
self
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
self
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
self
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
self
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
self
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
self
.
spec_decode_metrics
.
emitted_tokens
)
# Reset tracked stats for next interval.
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Info type metrics are syntactic sugar for a gauge permanently set to 1
...
...
vllm/engine/metrics_types.py
View file @
dd572c0a
...
@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
...
@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
List
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
@
dataclass
@
dataclass
...
@@ -65,8 +64,6 @@ class Stats:
...
@@ -65,8 +64,6 @@ class Stats:
running_lora_adapters
:
List
[
str
]
running_lora_adapters
:
List
[
str
]
max_lora
:
str
max_lora
:
str
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
class
StatLoggerBase
(
ABC
):
class
StatLoggerBase
(
ABC
):
"""Base class for StatLogger."""
"""Base class for StatLogger."""
...
@@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
...
@@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
@
abstractmethod
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
def
log
(
self
,
stats
:
Stats
)
->
None
:
...
@@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
...
@@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
@
abstractmethod
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
    """Remember the most recent spec decode metrics carried by ``stats``.

    Spec decode metrics are unlikely to be emitted at the same time as the
    log interval, so the latest non-``None`` value is kept on the logger
    until it is consumed.
    """
    latest = stats.spec_decode_metrics
    if latest is None:
        # Nothing new in this iteration; keep whatever was saved before.
        return
    self.spec_decode_metrics = latest
vllm/engine/output_processor/multi_step.py
View file @
dd572c0a
...
@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs
=
sequence_group
.
get_seqs
(
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
FINISHED_ABORTED
)
status
=
SequenceStatus
.
FINISHED_ABORTED
)
for
output
in
outputs
:
if
output
.
samples
[
0
].
output_token
!=
VLLM_INVALID_TOKEN_ID
:
sequence_group
.
metrics
.
spec_token_acceptance_counts
[
output
.
step_index
]
+=
1
assert
seqs
,
"Expected RUNNING or FINISHED_ABORTED sequences"
assert
seqs
,
"Expected RUNNING or FINISHED_ABORTED sequences"
assert
len
(
seqs
)
==
1
,
(
assert
len
(
seqs
)
==
1
,
(
"Beam search not supported in multi-step decoding."
)
"Beam search not supported in multi-step decoding."
)
...
...
vllm/model_executor/layers/rejection_sampler.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
functools
import
cached_property
from
importlib.util
import
find_spec
from
typing
import
Optional
import
torch
import
torch.jit
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeStochasticBaseSampler
)
from
vllm.platforms
import
current_platform
logger = init_logger(__name__)

# Consider utilizing the FlashInfer rejection sampling kernel initially,
# as it employs a dedicated kernel rather than relying on
# Torch tensor operations. This design choice helps to fuse operations,
# reduce memory I/O, and consequently enhances performance.
if not find_spec("flashinfer"):
    # FlashInfer is not installed; callers check for ``None`` and fall back
    # to the PyTorch implementation.
    chain_speculative_sampling = None
else:
    from flashinfer.sampling import chain_speculative_sampling
class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 strict_mode: bool = False,
                 use_flashinfer: Optional[bool] = None):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
            use_flashinfer: We will use this parameter to determine whether
                to use the FlashInfer rejection sampling kernel or not. If
                it's None, we will use the default value from the environment
                variable. This parameter is only used for testing purposes.
        """
        super().__init__(strict_mode=strict_mode)
        if use_flashinfer is None:
            # Only honour the environment switch when the kernel could
            # actually be imported.
            self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
                chain_speculative_sampling is not None)
        else:
            self.use_flashinfer = use_flashinfer

        if self.use_flashinfer:
            logger.info("Use flashinfer for rejection sampling.")
        else:
            logger.info("Use pytorch for rejection sampling.")

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_with_bonus_probs: The probability distribution
                over token ids given context according to the target model.
                shape = [batch_size, num_speculative_tokens + 1, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

            seeded_seqs: Dict of batch row index to torch generator, for
                sequences using seeded generation. Only the PyTorch path
                consumes it; it is not forwarded to the FlashInfer kernel.

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)

        batch_size, k, _ = draft_probs.shape

        # batch_size = 0 when all requests in the batch are
        # non_spec requests. In this case, output_token_ids is
        # just an empty tensor.
        if batch_size == 0:
            return torch.empty(0,
                               k + 1,
                               device=draft_probs.device,
                               dtype=torch.int64)

        # If use Flashinfer chain_speculative_sampling kernel
        # for rejection sampling
        if self.use_flashinfer and chain_speculative_sampling is not None:
            (output_token_ids, accepted_token_num,
             emitted_token_num) = chain_speculative_sampling(
                 draft_probs,
                 draft_token_ids,
                 target_with_bonus_probs,
             )

            # num_emitted_tokens returned by flashinfer
            # does not include the bonus token
            # Flashinfer stops at the first token that violates
            # the condition p >= q and does not include recovery/bonus token.
            # Therefore, we need to add batch_size here.
            self.num_accepted_tokens += accepted_token_num.sum()
            self.num_emitted_tokens += emitted_token_num.sum() + batch_size
            self.num_draft_tokens += batch_size * k
        else:
            # The bonus position is dropped: only the k draft positions take
            # part in accept/reject.
            accepted, recovered_token_ids = (
                self._batch_modified_rejection_sampling(
                    target_with_bonus_probs[:, :-1],
                    draft_probs,
                    draft_token_ids,
                    seeded_seqs,
                ))

            output_token_ids = self._create_output(
                accepted,
                recovered_token_ids,
                draft_token_ids,
                bonus_token_ids,
            )

        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, seeded_seqs)

        # Flattened to one row per (sequence, position) pair for sampling.
        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _create_uniform_samples(self,
                                seeded_seqs: Optional[dict[int,
                                                           torch.Generator]],
                                batch_size: int, k: int,
                                device: torch.device) -> torch.Tensor:
        """
        Generates a batch of uniform random samples, with optional seeding
        for specific sequences.

        This method creates a tensor of shape `(batch_size, k + 1)` filled
        with uniform random values in the range [0, 1). If `seeded_seqs`
        is provided, the sequences corresponding to specific indices
        will be generated using the provided `torch.Generator` for
        reproducibility. The other sequences will be generated without
        a seed.

        The result always has dtype `self.probs_dtype`, regardless of
        whether any sequence is seeded.

        Args:
            seeded_seqs : Optional[dict[int, torch.Generator]]
                A dictionary mapping indices in the batch to
                `torch.Generator` objects. If `None`, all samples are
                generated without a seed.
            batch_size : int
                The number of sequences to generate.
            k : int
                Each sequence receives `k + 1` random samples.
            device : torch.device
                The device on which to allocate the tensor.

        Returns:
            uniform_rand : torch.Tensor
                A tensor of shape `(batch_size, k + 1)` containing uniform
                random values in the range [0, 1).
        """
        if not seeded_seqs:
            return torch.rand(batch_size,
                              k + 1,
                              dtype=self.probs_dtype,
                              device=device)

        uniform_rand = torch.empty(batch_size,
                                   k + 1,
                                   dtype=self.probs_dtype,
                                   device=device)

        non_seeded_indices = []
        for idx in range(batch_size):
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.append(idx)
            else:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k + 1,
                                                  dtype=self.probs_dtype,
                                                  device=device,
                                                  generator=generator)
        # All unseeded rows are filled with a single call.
        if non_seeded_indices:
            uniform_rand[non_seeded_indices, :] = torch.rand(
                len(non_seeded_indices),
                k + 1,
                dtype=self.probs_dtype,
                device=device)
        return uniform_rand

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
        $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
        to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        $$
        \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                        {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
        $$

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        # The helper allocates (k_arg + 1) columns, so k - 1 yields exactly
        # one uniform sample per draft position.
        uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
                                                    k - 1, target_probs.device)

        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
            self,
            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
        $x$ given context $x_1, \dots, x_n$ according to the target
        model and $p(x|x_1, \dots, x_n)$, the same conditional probability
        according to the draft model:

        $$
        x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
        $$

        where $(f(x))_+$ is defined as:

        $$
        (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
        $$

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note:
            This batches operations on GPU and thus constructs the recovered
            distribution for all tokens, even if they are accepted. This causes
            division-by-zero errors, so we use self._smallest_positive_value to
            avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny
# torch.multinomial forces a GPU<->CPU sync, so this module ships its own
# sync-free sampler instead. Sampling is always done with replacement.
# The probs tensor is divided in place; callers already pass in a copy, so
# that is fine.
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    k: int,
    seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw token indices from each row of ``probs``.

    Each row is divided elementwise by Exponential(1) noise and the argmax
    of the result is taken as the sample.

    Args:
        probs: 2-D tensor of per-row probabilities; modified in place.
        num_samples: Number of samples to draw per input row.
        k: Number of consecutive rows belonging to one sequence; used to
            slice the noise tensor per sequence when seeding.
        seeded_seqs: Maps a sequence index to its torch.Generator. When
            empty/falsy, the global RNG is used for all rows.

    Returns:
        Tensor of sampled indices viewed as [-1, num_samples].
    """
    if num_samples > 1:
        # Equivalent to torch.repeat_interleave, which would also force a
        # GPU<->CPU sync.
        num_rows = probs.shape[0]
        vocab_size = probs.shape[1]
        repeated = probs[:, None, :].expand(num_rows, num_samples,
                                            vocab_size)
        probs = repeated.contiguous().view(-1, vocab_size)

    noise = torch.empty_like(probs)
    if seeded_seqs:
        # Fill the noise one sequence (k rows) at a time so that each seeded
        # sequence consumes only its own generator.
        for seq_idx in range(len(noise) // k):
            row_lo = seq_idx * k
            row_hi = row_lo + k
            # The lookup yields None for sequences that are not seeded.
            seq_generator = seeded_seqs.get(seq_idx)
            noise[row_lo:row_hi].exponential_(1.0, generator=seq_generator)
    else:
        noise.exponential_(1.0)

    return probs.div_(noise).argmax(dim=1).view(-1, num_samples)
vllm/model_executor/layers/sampler.py
View file @
dd572c0a
...
@@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingType
...
@@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
# yapf: disable
# yapf: disable
...
@@ -119,9 +118,6 @@ class SamplerOutput(
...
@@ -119,9 +118,6 @@ class SamplerOutput(
# specified in lieu of prompt token ids or text.
# specified in lieu of prompt token ids or text.
sampled_token_embeds
:
Optional
[
torch
.
Tensor
]
=
None
sampled_token_embeds
:
Optional
[
torch
.
Tensor
]
=
None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
# Optional last hidden states from the model.
# Optional last hidden states from the model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -159,11 +155,9 @@ class SamplerOutput(
...
@@ -159,11 +155,9 @@ class SamplerOutput(
else
self
.
sampled_token_probs
.
shape
)
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
self
.
sampled_token_ids
.
shape
)
return
(
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
)"
)
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
abstractmethod
from
typing
import
Optional
,
Union
import
torch
import
torch.jit
import
torch.nn
as
nn
from
vllm.platforms
import
current_platform
class SpecDecodeBaseSampler(nn.Module):
    """Common machinery for samplers that verify speculative-decoding
    proposals: output formatting, acceptance counters, and optional input
    validation.
    """

    def __init__(self, strict_mode: bool = False):
        """Set up the sampler state.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always exactly one possible bonus token; the
        # value is kept in a named attribute for readability.
        self._num_bonus_tokens = 1

        # Running counters. The two tensor counters stay None until one of
        # the init_*tensors methods places them on a device.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        """Allocate the counter tensors on ``device``.

        An int is interpreted as a device index on the current platform's
        device type; a str is used as given.

        Raises:
            ValueError: If ``device`` is neither an int nor a str.
        """
        assert self.num_accepted_tokens is None
        if isinstance(device, int):
            device = f"{current_platform.device_type}:{device}"
        elif not isinstance(device, str):
            raise ValueError(f"Device must be int or str, get {type(device)}")

        def new_counter() -> torch.Tensor:
            return torch.tensor(0, dtype=torch.long, device=device)

        self.num_accepted_tokens = new_counter()
        self.num_emitted_tokens = new_counter()

    def init_tensors(self,
                     device: Union[int, str],
                     device_type: Union[torch.device, str] = 'cuda') -> None:
        """Allocate the counter tensors on ``device``.

        An int ``device`` is combined with ``device_type`` (a str, or the
        ``type`` of a torch.device) into "<type>:<index>"; a str is used as
        given.
        """
        assert self.num_accepted_tokens is None
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        if isinstance(device, int):
            device = f"{device_type}:{device}"

        def new_counter() -> torch.Tensor:
            return torch.tensor(0, dtype=torch.long, device=device)

        self.num_accepted_tokens = new_counter()
        self.num_emitted_tokens = new_counter()

    @property
    def probs_dtype(self):
        """dtype expected for probability tensors."""
        return torch.float32

    @property
    def token_id_dtype(self):
        """dtype expected for token-id tensors."""
        return torch.int64

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size] after squeeze(-1)
    ) -> torch.Tensor:
        """Assemble the matrix of emitted token ids.

        For each row, draft tokens are kept up to the first rejection, the
        rejected position receives the substitute token, and every later
        position is set to -1. The bonus token is emitted only when the last
        draft position was kept.

        Also updates the num_accepted_tokens, num_emitted_tokens and
        num_draft_tokens counters.

        Args:
            accepted: A boolean tensor indicating if the corresponding
                draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
                as substitutes for the draft token ids if the proposed token
                is rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.

        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze(-1)

        # Index of the first rejected position per row; rows without any
        # rejection get the sentinel k.
        rejected = accepted == 0
        limits = rejected.max(1).indices
        limits[~rejected.any(1)] = k

        # Positional masks relative to the first rejection.
        positions = torch.arange(k, device=accepted.device).unsqueeze(0)
        first_rejection = limits.unsqueeze(1)
        accepted_mask = positions < first_rejection
        after_false_mask = positions == first_rejection

        # Output buffer pre-filled with -1; `output` is a view on the first
        # k columns, so writes through it land in the full buffer.
        output_with_bonus_tokens = torch.full(
            (batch_size, k + self._num_bonus_tokens),
            -1,
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Keep the draft tokens before the first rejection, -1 elsewhere.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    torch.full_like(draft_token_ids, -1))

        # Fill the last column.
        # The output is inspected directly because `accepted` may contain
        # True values that are inconsistent with causal acceptance.
        last_draft_kept = output[:, -1] != -1
        output_with_bonus_tokens[:, -1] = torch.where(last_draft_kept,
                                                      bonus_token_ids, -1)

        # Write the recovered token at the first rejected position: zero that
        # slot, then add the substitute id there.
        output.mul_(~after_false_mask)
        output.add_(substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Run the shape, dtype, device and vocab-range checks in turn."""
        tensor_args = (target_with_bonus_probs, draft_token_ids,
                       bonus_token_ids, draft_probs)
        self._raise_if_incorrect_shape(*tensor_args)
        self._raise_if_incorrect_dtype(*tensor_args)
        self._raise_if_inconsistent_device(*tensor_args)
        self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that batch, length and vocab dimensions agree."""
        target_batch_size, num_target_probs, target_vocab_size = (
            target_with_bonus_probs.shape)

        # The target probs carry one extra position for the bonus token.
        num_target_probs -= 1

        # Draft token ids must be [batch_size, k].
        draft_token_ids_batch_size, num_draft_token_ids = (
            draft_token_ids.shape)
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # Bonus token ids must be [batch_size, num_bonus_tokens].
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # Draft probs are optional; nothing more to check without them.
        if draft_probs is None:
            return

        draft_batch_size, num_draft_probs, draft_vocab_size = (
            draft_probs.shape)
        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert probability and token-id tensors use the expected dtypes."""
        assert target_with_bonus_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that all provided tensors live on the same device."""
        candidates = (target_with_bonus_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)
        devices = [t.device for t in candidates if t is not None]
        assert all(devices[0] == device for device in devices)

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        """Assert every bonus/draft token id lies in [0, vocab_size)."""
        for token_ids in (bonus_token_ids, draft_token_ids):
            assert torch.all(token_ids < vocab_size)
            assert torch.all(token_ids >= 0)
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Abstract base for speculative-decoding verification samplers whose
    acceptance decision is deterministic.
    """

    @abstractmethod
    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Verify the draft tokens; concrete samplers must override this."""
        raise NotImplementedError()
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Abstract base for speculative-decoding verification samplers whose
    acceptance decision is stochastic.
    """

    @abstractmethod
    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Verify the draft tokens; concrete samplers must override this.

        ``seeded_seqs`` optionally maps an int key to a torch.Generator.
        """
        raise NotImplementedError()
vllm/model_executor/layers/typical_acceptance_sampler.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.jit
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeDeterministicBaseSampler
)
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
    """Typical acceptance sampling, as described in section 3.3.1 of
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            posterior_threshold : A threshold value that sets a lower bound
                on the posterior probability of a token in target model for
                it to be accepted.
            posterior_alpha : A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__(strict_mode=strict_mode)
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling.

        Draft tokens are accepted or rejected based on their probability
        under the target model. The first rejected position is filled with
        the target model's most likely token, so a row whose drafts are all
        rejected still emits one token. When every draft token is accepted,
        the bonus token is emitted as well.

        Args:
            target_with_bonus_probs: The probability distribution over token
                ids given context according to the target model. The last
                position along dim 1 is dropped before evaluating the draft
                tokens.
                shape = [batch_size, num_speculative_tokens + 1, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: This parameter is unused by the acceptance sampler.
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids selected via typical acceptance,
                or -1 where no token could be emitted because an earlier
                token was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Shape/dtype/device checking adds overhead, so it only runs in
        # strict mode.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids)

        # Drop the trailing (bonus) position so the probabilities line up
        # with the draft tokens.
        target_probs = target_with_bonus_probs[:, :-1]
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._get_recovered_token_ids(target_probs)
        return self._create_output(accepted, recovered_token_ids,
                                   draft_token_ids, bonus_token_ids)

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""Compute the boolean acceptance mask for the draft tokens.

        Args:
            target_probs (torch.Tensor): A tensor of shape
                (batch_size, k, vocab_size) representing the probabilities of
                each token in the vocabulary for each position in the
                proposed sequence. This is the distribution generated by the
                target model.
            draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
                representing the proposed token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        $$
        p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
        \min \left( \epsilon, \delta * \exp \left(
            -H(p_{\text{original}}(
                \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
        $$

        where $p_{\text{original}}$ corresponds to target_probs
        and $\epsilon$ and $\delta$ correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        In code: the target probability of each draft token is gathered, the
        entropy of the target distribution at each position is computed, and
        a per-position threshold is derived from that entropy together with
        posterior_threshold and posterior_alpha. A token is accepted when its
        probability strictly exceeds the threshold.

        Returns:
            torch.Tensor: A boolean tensor of shape (batch_size, k) where
                each element indicates whether the corresponding draft token
                has been accepted or rejected. True indicates acceptance and
                false indicates rejection.
        """
        device = target_probs.device

        # Target-model probability assigned to each proposed token.
        draft_index = draft_token_ids.unsqueeze(-1)
        candidates_prob = torch.gather(target_probs,
                                       dim=-1,
                                       index=draft_index).squeeze(-1)

        # A small constant keeps the logarithm away from log(0), which would
        # produce undefined values.
        epsilon = 1e-5
        log_probs = (target_probs + epsilon).log()
        posterior_entropy = -(target_probs * log_probs).sum(dim=-1)

        # threshold = min(posterior_threshold, posterior_alpha * exp(-H))
        fixed_bound = (torch.ones_like(posterior_entropy, device=device) *
                       self._posterior_threshold)
        entropy_bound = torch.exp(-posterior_entropy) * self._posterior_alpha
        threshold = torch.minimum(fixed_bound, entropy_bound)

        return candidates_prob > threshold

    def _get_recovered_token_ids(self, target_probs):
        """Pick the replacement token for each position: the target model's
        most likely token.

        Args:
            target_probs (torch.Tensor): A tensor of shape
                (batch_size, k, vocab_size) containing the target probability
                distribution.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, k) with the
                recovered token ids which are selected from target probs.
        """
        return target_probs.argmax(dim=-1)
vllm/model_executor/models/eagle.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
logger
=
init_logger
(
__name__
)
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for an input layer norm.

    It can still hold ``weight``/``bias`` parameters, but ``forward``
    returns its input untouched.
    """

    def __init__(self, weight=None, bias=None):
        super().__init__()
        # Wrap the provided tensors as parameters; leave the attribute as
        # None when nothing was provided.
        if weight is None:
            self.weight = None
        else:
            self.weight = nn.Parameter(weight)

        if bias is None:
            self.bias = None
        else:
            self.bias = nn.Parameter(bias)

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class DummyOutputNorm(nn.Module):
    """Stand-in for an output norm that applies no normalization."""

    def forward(self, x, residual):
        """Return ``x`` as is when there is no residual; otherwise return the
        pair ``(x + residual, None)``.
        """
        if residual is not None:
            return x + residual, None
        return x
class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In reference, LlamaDecoderLayer implementation doesn't have
       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE whereas in reference
       decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute.
    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
       module with regards to the use of additional RMS norms. The original
       EAGLE architecture 1) skips the pre-attention norm in its first
       transformer block, and 2) skips the final output norm, both of which we
       found to be suboptimal. We also add the support for separate norms
       applying to both the token embedding and hidden states before projection
       as in DeepSeek MTP, which we found to improve performance as well.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped decoder, the fusion layer and the LM head.

        Args:
            vllm_config: Engine configuration; the HF config and the model
                dtype are read from ``vllm_config.model_config``.
            prefix: Prefix under which the wrapped model is registered.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.dtype = vllm_config.model_config.dtype
        self.config = config

        # Resolve and instantiate the wrapped decoder from the architectures
        # listed on the nested model config.
        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Projects the concatenation [token embedding, previous hidden state]
        # (2 * hidden_size) back to hidden_size.
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))

        # Modify layer normalization and residual connections as suggested
        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
        # While weights and biases are generally not needed,
        # they are retained here to support certain unit tests
        # (e.g., spec_decode/e2e/test_eagle_correctness.py).
        # Both norms are replaced by default, i.e. when the skip_* attribute
        # is absent or truthy.
        if not hasattr(self.config.model,
                       "skip_prenorm") or self.config.model.skip_prenorm:
            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
                weight=self.model.model.layers[0].input_layernorm.weight)

        if not hasattr(
                self.config.model,
                "skip_output_norm") or self.config.model.skip_output_norm:
            self.model.model.norm = DummyOutputNorm()

        # Optional separate RMS norms for the token embedding (enorm) and
        # the previous hidden state (hnorm); off unless the config opts in.
        self.add_para_norm = False
        if hasattr(self.config.model,
                   "add_para_norm") and self.config.model.add_para_norm:
            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.add_para_norm = True

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Token map is a idx to token mapping to reduce the vocab size for
        # the draft model. Using smaller vocab size for draft, containing
        # only most frequent tokens reduces the speculation overhead. This
        # doesn't affect the acceptance rate much and thus gives more speed
        # -up. By default, this is disabled and is only used if the EAGLE
        # checkpoint file has token_map tensor.
        self.token_map = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the wrapped model's embedding layer."""
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the draft model on token embeddings fused with the previous
        hidden states.

        Args:
            input_ids: Token ids; only used to compute embeddings when
                ``inputs_embeds`` is not given.
            positions: Position of each token; entries equal to 0 have their
                fused input zeroed.
            previous_hidden_states: Hidden states to fuse with the
                embeddings. Replaced by zeros when empty or when its first
                dimension does not match the embeddings' first dimension.
            intermediate_tensors: Passed through to the wrapped model.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            The hidden states produced by the wrapped model.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # Handle both empty previous_hidden_states
        # and mismatched batch size
        batch_size = inputs_embeds.size(0)
        if previous_hidden_states.size(0) == 0 or \
            previous_hidden_states.size(0) != batch_size:
            hidden_dim = self.config.model.hidden_size
            # Create zero tensor with matching batch size. Match the
            # embeddings' dtype as well: with the default dtype the
            # concatenation below would be promoted (e.g. to float32 for a
            # half-precision model) and no longer match the dtype of the
            # fusion layer's weights.
            previous_hidden_states = torch.zeros(batch_size,
                                                 hidden_dim,
                                                 dtype=inputs_embeds.dtype,
                                                 device=inputs_embeds.device)

        if self.add_para_norm:
            inputs_embeds = torch.cat([
                self.enorm(inputs_embeds),
                self.hnorm(previous_hidden_states)
            ],
                                      dim=-1)
        else:
            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
                                      dim=-1)

        inputs_embeds = self.fc(inputs_embeds)

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits, expanded to the original vocabulary when a token
        map is in use.

        With a token map, the draft logits are scattered into a tensor of
        width ``orig_vocab_size`` that is -inf everywhere else.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is not None:
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            logits[..., self.token_map] = _logits

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load EAGLE-specific weights and forward the rest to the wrapped
        model.

        ``token_map``, ``fc.*``, ``enorm.weight`` and ``hnorm.weight`` are
        handled here; all other entries are renamed as needed and passed to
        ``self.model.load_weights``. The LM head weight is loaded separately
        (row-selected through the token map when applicable); a zero
        placeholder is loaded when the checkpoint has no ``lm_head.weight``.
        """
        # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting trained EAGLE
        # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                # Only keep the token map when the draft vocab is actually
                # truncated.
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    weight_loader = getattr(self.fc.bias, "weight_loader",
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
                    logger.warning_once("Found bias in the loaded weights but "
                                        "the model config doesn't have bias.")
            elif name.startswith("enorm.weight"):
                weight_loader = getattr(self.enorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.enorm.weight, loaded_weight)
            elif name.startswith("hnorm.weight"):
                weight_loader = getattr(self.hnorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.hnorm.weight, loaded_weight)
            elif name.startswith(("model.lm_head.", "model.model.")):
                # Strip the leading "model." so the key matches the wrapped
                # model's own naming.
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith(("lm_head.", "model.")):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        if "lm_head.weight" in model_weights:
            lm_head_weight = model_weights.pop("lm_head.weight")

            # Select the token-map rows when the checkpoint head covers more
            # rows than the token map.
            if self.token_map is not None and\
                lm_head_weight.shape[0] > self.token_map.shape[0]:

                lm_head_weight = lm_head_weight[self.token_map]

        else:
            # NOTE(Shangming): initialize the placeholder for lm_head weight.
            lm_head_weight = torch.zeros(
                self.lm_head.org_vocab_size,
                self.lm_head.embedding_dim,
                dtype=self.dtype,
            )

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())
vllm/model_executor/models/registry.py
View file @
dd572c0a
...
@@ -239,14 +239,15 @@ _MULTIMODAL_MODELS = {
...
@@ -239,14 +239,15 @@ _MULTIMODAL_MODELS = {
_SPECULATIVE_DECODING_MODELS
=
{
_SPECULATIVE_DECODING_MODELS
=
{
"MiMoMTPModel"
:
(
"mimo_mtp"
,
"MiMoMTP"
),
"MiMoMTPModel"
:
(
"mimo_mtp"
,
"MiMoMTP"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
}
_TRANSFORMERS_MODELS
=
{
_TRANSFORMERS_MODELS
=
{
...
...
vllm/platforms/cuda.py
View file @
dd572c0a
...
@@ -132,14 +132,10 @@ class CudaPlatformBase(Platform):
...
@@ -132,14 +132,10 @@ class CudaPlatformBase(Platform):
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.worker.multi_step_worker.MultiStepWorker"
"vllm.worker.multi_step_worker.MultiStepWorker"
elif
vllm_config
.
speculative_config
:
elif
vllm_config
.
speculative_config
:
if
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
raise
NotImplementedError
(
"vllm.v1.worker.gpu_worker.Worker"
"Speculative decoding is not supported on vLLM V0."
)
else
:
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.worker.Worker"
else
:
else
:
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
...
...
vllm/platforms/rocm.py
View file @
dd572c0a
...
@@ -326,15 +326,10 @@ class RocmPlatform(Platform):
...
@@ -326,15 +326,10 @@ class RocmPlatform(Platform):
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.worker.multi_step_worker.MultiStepWorker"
"vllm.worker.multi_step_worker.MultiStepWorker"
elif
vllm_config
.
speculative_config
:
elif
vllm_config
.
speculative_config
:
if
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Speculative decoding is not yet supported on vLLM V1."
"Speculative decoding is not supported on vLLM V0."
)
)
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
else
:
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.worker.Worker"
else
:
else
:
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
...
...
vllm/sequence.py
View file @
dd572c0a
...
@@ -112,13 +112,6 @@ class RequestMetrics:
...
@@ -112,13 +112,6 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
"""
"""
arrival_time
:
float
arrival_time
:
float
last_token_time
:
float
last_token_time
:
float
...
@@ -129,7 +122,6 @@ class RequestMetrics:
...
@@ -129,7 +122,6 @@ class RequestMetrics:
scheduler_time
:
Optional
[
float
]
=
None
scheduler_time
:
Optional
[
float
]
=
None
model_forward_time
:
Optional
[
float
]
=
None
model_forward_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
spec_token_acceptance_counts
:
Optional
[
list
[
int
]]
=
None
class
SequenceDataDelta
(
class
SequenceDataDelta
(
...
@@ -748,9 +740,7 @@ class SequenceGroup:
...
@@ -748,9 +740,7 @@ class SequenceGroup:
last_token_time
=
arrival_time
,
last_token_time
=
arrival_time
,
first_scheduled_time
=
None
,
first_scheduled_time
=
None
,
first_token_time
=
None
,
first_token_time
=
None
,
time_in_queue
=
None
,
time_in_queue
=
None
)
spec_token_acceptance_counts
=
[
0
]
*
draft_size
)
self
.
last_token_latency
=
0.0
self
.
last_token_latency
=
0.0
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
...
@@ -1390,8 +1380,6 @@ class ExecuteModelRequest(
...
@@ -1390,8 +1380,6 @@ class ExecuteModelRequest(
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
# The number of forward steps to run.
# The number of forward steps to run.
num_steps
:
int
=
1
num_steps
:
int
=
1
# The step index for spec model input.
spec_step_idx
:
Optional
[
int
]
=
None
# Finished request ids since last step.
# Finished request ids since last step.
finished_requests_ids
:
list
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
list
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
...
...
vllm/spec_decode/__init__.py
deleted
100644 → 0
View file @
9ffe905a
vllm/spec_decode/batch_expansion.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
array
import
array
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
# Readability aliases used in the signatures below. Ids of input sequences,
# ids of the expanded-batch ("target") sequences, and token ids are all
# plain ints.
SeqId
=
int
TargetSeqId
=
int
TokenId
=
int
# Default-constructed sampling parameters.
# NOTE(review): not referenced by the scorer code visible in this file;
# confirm whether anything else imports it before removing.
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Speculative scorer that obtains target-model probabilities through
    batch expansion.

    Batch expansion turns a list of sequences with several query positions
    each into a new batch in which every sequence has exactly one query
    position. This gives MQA-like scoring without needing an MQA kernel, at
    the price of being strictly less efficient than MQA scoring.

    Only the top-1 proposal tokens of the proposer can be scored; top-k and
    tree proposals are not supported.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens with the scorer model.

        Every input sequence is converted into k+1 target sequences. Each
        target sequence carries one distinct continuation to score and a
        sequence id that differs from all input sequence ids.

        If a speculative sequence would exceed the max model length, no
        speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """
        # TODO(cade) perform this on GPU to remove blocking call.
        lens_per_seq = proposals.proposal_lens.tolist()
        token_ids_per_seq = proposals.proposal_token_ids.tolist()

        # Rows that contain the invalid-token sentinel are not proposals;
        # leave them out.
        valid_token_ids = [
            row for row in token_ids_per_seq
            if VLLM_INVALID_TOKEN_ID not in row
        ]

        expanded = self._expand_batch(
            seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
            proposal_token_ids_list=valid_token_ids,
            proposal_lens_list=lens_per_seq,
        )
        (spec_indices, non_spec_indices, expanded_metadata,
         num_scoring_tokens) = expanded

        step_outputs = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=expanded_metadata))
        assert len(step_outputs) == 1, "expected single-step output"
        target_sampler_output = step_outputs[0]

        if non_spec_indices:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

        # All sequence groups in batch have spec decoding enabled
        return self._contract_batch_all_spec(
            target_sampler_output=target_sampler_output,
            proposals=proposals,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Build the expanded batch in which each sequence has a single
        query token.

        Returns:
            A tuple ``(spec_indices, non_spec_indices, expanded_batch,
            num_scoring_tokens)`` where ``expanded_batch`` lists the
            non-speculative sequences first, followed by the expanded
            speculative ones, and ``num_scoring_tokens`` is the number of
            expanded speculative sequences.
        """
        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. Splitting the batch into spec and non-spec
        # sequences adds complexity and could be removed by supporting
        # per-sequence proposal lens.
        spec_part, non_spec_part = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        spec_seqs, spec_indices = spec_part
        non_spec_seqs, non_spec_indices = non_spec_part

        # NOTE: The seq ids for the expanded batch are derived from the full
        # seq_group_metadata_list, not only from spec_seqs.
        target_ids = self._create_target_seq_id_iterator(
            seq_ids=get_all_seq_ids(seq_group_metadata_list))
        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            target_seq_ids_iter=target_ids,
        )
        num_scoring_tokens = len(spec_expanded_seqs)

        # Batch speculative and non-speculative (e.g. chunked prefill)
        # requests, keeping the prefill|decode order the backend requires.
        expanded_batch = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, expanded_batch,
                num_scoring_tokens)

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """Write the outputs of non-speculative requests into ``scores``.

        Non-speculative requests are decode requests with speculation turned
        off and, when `enable_chunked_prefill` is set, prefill requests.
        Prefills are further separated into terminal and non-terminal chunks
        (no token is sampled from the latter).

        ``scores`` is modified in place and also returned.
        """
        if not non_spec_indices:
            return scores

        if not has_prompt_log:
            # In this case only sampled tokens are returned, select all.
            picked = list(range(len(non_spec_outputs.token_ids)))
        else:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            entry_sizes = torch.tensor([
                seq_group_metadata_list[i].token_chunk_size
                if seq_group_metadata_list[i].is_prompt else 1
                for i in non_spec_indices
            ])
            # Index of the last entry of each request.
            picked = torch.cumsum(entry_sizes, 0).add_(-1)

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[picked].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[picked].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[picked].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[picked].unsqueeze(1)
        return scores

    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.

        Maps the scores of speculative tokens back to their original
        sequences. The original batch size (``contracted_bs`` below) is the
        length of ``contracted_seq_group_metadata_list``.

        Note that the ``k`` argument is overwritten with the second dimension
        of ``proposals.proposal_token_ids`` before it is ever read.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)

        (target_token_ids, target_probs, target_logprobs,
         target_hidden_states, non_spec_target_token_ids,
         non_spec_target_probs, non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # The scorer output has one row per expanded sequence, i.e.
        # batch_size * (k + 1) rows; fold it back to [batch_size, k + 1].
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # the total expanded batch size minus the number of samples for
        # non-speculative sequences, prefill chunks with no out tokens
        # included.
        spec_expanded_bs = expanded_batch_size - len(non_spec_indices)

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Full-size result buffers, pre-filled with "nothing here" values.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape,
                                           self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        all_hidden_states = None
        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)

        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # The stride is adjusted accordingly to get the generated tokens and
        # their probs, while prompt_logprobs are passed on as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs
                and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned, non-terminal
            # chunks (no out token) can be ignored.
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)

        # Contract remaining nonspec entries based on non_spec_indices, if
        # any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list,
            non_spec_indices, non_spec_outputs, has_prompt_log)

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.

        Maps the scores of speculative tokens back to their original
        sequences, assuming every sequence in the batch was expanded.
        """
        # The scorer output has one row per expanded sequence, i.e.
        # batch_size * (k + 1) rows; fold it back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        probs = target_sampler_output.sampled_token_probs.reshape(
            *token_ids.shape, self._vocab_size)
        logprobs = target_sampler_output.logprobs.reshape(probs.shape)

        hidden_states = target_sampler_output.hidden_states
        if hidden_states is not None:
            hidden_states = hidden_states.reshape(*token_ids.shape,
                                                  hidden_states.shape[-1])

        return SpeculativeScores(probs=probs,
                                 token_ids=token_ids,
                                 logprobs=logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=None)

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Create the target sequences used for scoring from the original
        input sequences and the draft model's proposed tokens.

        ``target_seq_ids_iter`` supplies the sequence ids for the expanded
        batch, so that no seq id in the expanded batch equals a seq id in
        the original batch.
        """
        if not seq_group_metadata_list:
            return []

        expanded: List[SequenceGroupMetadata] = []
        for batch_index, metadata in enumerate(seq_group_metadata_list):
            expanded.extend(
                self._create_target_seq_group_metadata(
                    metadata,
                    proposal_token_ids,
                    batch_index,
                    target_seq_ids_iter,
                ))
        return expanded

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Create one target SequenceGroupMetadata per token id to score.

        Naive speculative decoding requires K target model scores, one for
        each draft model token. Adding a bonus token lets a final token be
        sampled from the model when every draft token is accepted, so K+1
        target SequenceGroupMetadata are created.
        """
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        prefixes_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])
        sampling_params = input_seq_group_metadata.sampling_params

        return [
            self._create_single_target_seq_group_metadata(
                input_seq_group_metadata,
                input_seq_id,
                next(target_seq_ids_iter),
                prefix,
                sampling_params=sampling_params,
            ) for prefix in prefixes_to_score
        ]

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling parameters attached to the new
                metadata.
        """
        source_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = source_data.prompt_token_ids_array
        extended_output_ids = [
            *source_data.get_output_token_ids(), *token_ids
        ]
        mrope_position_delta = source_data.mrope_position_delta

        target_data = SequenceData(
            prompt_token_ids,
            _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                    extended_output_ids),
        )
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluates one by one right now. context_len is seq_len - 1
        # because the kv cache is filled by a previous batch in the batch
        # expansion.
        target_data.update_num_computed_tokens(target_data.get_len() - 1)
        target_data.mrope_position_delta = mrope_position_delta

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={target_seq_id: target_data},
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    @staticmethod
    def _split_scoring_output(
        sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
        """Split the target model output into speculative and
        non-speculative output.

        Returns:
            ``(spec tokens, spec probs, spec logprobs, spec hidden states,
            non-spec tokens, non-spec probs, non-spec logprobs, non-spec
            hidden states)``; both hidden-state entries are None when the
            sampler output carries no hidden states.
        """
        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. Splitting the batch into spec and non-spec
        # sequences adds complexity and could be removed by supporting
        # per-sequence proposal lens.
        #
        # First samples are non-speculative, latter samples are from
        # speculative scoring (prefill|decode order).
        num_non_spec = (sampler_output.sampled_token_ids.numel() -
                        num_scoring_tokens)
        split_sizes = (num_non_spec, num_scoring_tokens)

        non_spec_probs, spec_probs = \
            sampler_output.sampled_token_probs.split(split_sizes)
        non_spec_tokens, spec_tokens = \
            sampler_output.sampled_token_ids.flatten().split(split_sizes)
        non_spec_logprobs, spec_logprobs = \
            sampler_output.logprobs.split(split_sizes)

        non_spec_hidden_states, spec_hidden_states = None, None
        if sampler_output.hidden_states is not None:
            non_spec_hidden_states, spec_hidden_states = \
                sampler_output.hidden_states.split(split_sizes)

        return (spec_tokens, spec_probs, spec_logprobs, spec_hidden_states,
                non_spec_tokens, non_spec_probs, non_spec_logprobs,
                non_spec_hidden_states)

    @staticmethod
    def _create_target_seq_id_iterator(
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator that hands out target sequence ids.

        Target sequence ids are distinct from sequence ids because a
        distinct target sequence id is created for each proposal token to be
        scored.

        The iterator is a counter starting at 1 + max of all provided input
        sequence ids.
        """
        first_free_id = max(seq_ids) + 1
        return count(start=first_free_id)

    @staticmethod
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Return every prefix of the proposal token ids, shortest first.

        Returns k+1 output lists. The additional one (the empty prefix) is
        used for generating the bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        prefixes: List[List[TokenId]] = [[]]
        prefixes += [
            full_spec_token_ids[:end]
            for end in range(1,
                             len(full_spec_token_ids) + 1)
        ]
        return prefixes
vllm/spec_decode/draft_model_runner.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.sampler
import
SamplerOutput
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
except
(
ModuleNotFoundError
,
ImportError
):
# vllm_flash_attn is not installed, try the ROCm FA metadata
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
except
(
ModuleNotFoundError
,
ImportError
)
as
err
:
raise
RuntimeError
(
"Draft model speculative decoding currently only supports "
"CUDA and ROCm flash attention backend."
)
from
err
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerWrapperBase
)
# Module-level logger, created via vLLM's init_logger for this module name.
logger
=
init_logger
(
__name__
)
# A flag to enable debug prints for the updated input tensors
# before each step. Read by TP1DraftModelRunner._gpu_advance_step.
debug_advance_input
=
False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging. When False,
# TP1DraftModelRunner.supports_gpu_multi_step always returns False.
allow_gpu_advance_step
=
True
class TP1DraftModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    TODOs:
    1. Currently supports only flash-attn, add support for other
       attn_backends.
    2. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, model_runner: ModelRunnerBase):
        """Wrap ``model_runner``; no bonus-token indices are set initially."""
        super().__init__(model_runner)

        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):
        """Check that ``sampling_metadata`` describes a decode-only batch.

        Despite its name, this method mutates nothing: it only asserts that
        there are no prompts, that there is one sequence group per query and
        that group ``i`` samples exactly index ``i``. ``num_seqs`` is
        accepted but not used.
        """
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple

    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        """Build the model input for the next draft step.

        Calls ``advance_step`` on the attention metadata with the tokens
        sampled in ``last_output``, validates the sampling metadata, and
        returns a new model input that reuses the tensors of
        ``model_input``.

        Args:
            model_input: Input of the step that was just executed; must be
                a decode (non-prompt) input with flash-attention metadata.
            last_output: Sampler output of that step; its
                ``sampled_token_ids`` must not be None.

        Returns:
            The model input for the following step.
        """
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)

        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the
        # same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug(" input_tokens = %s", new_model_input.input_tokens)
            logger.debug(" input_positions = %s",
                         new_model_input.input_positions)
            # seq_lens and query_lens are sized containers (len() is taken
            # of them above), so they must be formatted with %s; %d would
            # fail inside logging and the record would be lost.
            logger.debug(" seq_lens = %s", new_model_input.seq_lens)
            logger.debug(" query_lens = %s", new_model_input.query_lens)
            logger.debug(" attn_metadata:")
            logger.debug(" seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug(" block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config

        The module-level ``allow_gpu_advance_step`` flag must also be set.
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        """Store the indices consulted by ``execute_model`` when it rewrites
        sampled tokens of decode batches."""
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelRunnerInputBase,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes with advancement of input
        tensors on the GPU. Look at supports_gpu_multi_step(..) for
        pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they
                have a repeating sampling logic)

        Args:
            model_input: Input of the first step.
            kv_caches: Not referenced in this method.
            previous_hidden_states: Optional hidden states forwarded to the
                model as ``previous_hidden_states``.
            intermediate_tensors: Forwarded to the model unchanged.
            num_steps: Number of forward passes; 1 selects the fallback
                path.
            **kwargs: Only ``spec_step_idx`` is read, and only when the model
                config has ``num_nextn_predict_layers``.

        Returns:
            One sampler output per executed step, or an empty list on a
            non-driver worker.

        Raises:
            ValueError: If ``num_steps > 1`` is combined with a non-driver
                worker, LoRA, a prompt adapter config, input embeddings,
                multi-modal kwargs or a prefill batch.
        """
        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.inputs_embeds is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "inputs_embeds")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            # Graph runners are keyed by (batch size, uses inputs_embeds).
            if model_input.inputs_embeds is None:
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, False)])
            else:
                graph_batch_size = model_input.inputs_embeds.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, True)])

            if previous_hidden_states is not None:
                # Extend the hidden states along dim 0 with uninitialized
                # rows so that the first dimension equals graph_batch_size.
                hidden_states = torch.cat([
                    previous_hidden_states,
                    torch.empty([
                        graph_batch_size - previous_hidden_states.shape[0],
                        *previous_hidden_states.shape[1:]
                    ],
                                dtype=previous_hidden_states.dtype,
                                device=previous_hidden_states.device)
                ])
            else:
                hidden_states = None
        else:
            model_executable = self.model
            hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                if previous_hidden_states is not None else {}

            compute_logits_kwargs = {}
            # Run model
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):
                hidden_states = model_executable(
                    input_ids=model_input.input_tokens,
                    inputs_embeds=None,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(
                        multi_modal_kwargs,
                        device=self.device,
                    ),
                    **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata,
                                               **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            output = self.model_runner.sampler(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if self.return_hidden_states and is_fallback:
                if use_cuda_graph:
                    # Keep only as many rows as there are selected tokens.
                    indices = model_input.sampling_metadata\
                      .selected_token_indices
                    output.hidden_states = hidden_states[:len(indices)]
                else:
                    output.hidden_states = hidden_states

            if model_input.attn_metadata.num_prefills == 0 \
                and self.indices_of_seq_with_bonus_tokens is not None:
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
                assert num_tokens_per_seq == 1
                count = 0
                for i in range(nums_seqs):
                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                        count]
                    if i != bonus_seq_idx:
                        # The following might cause a cpu->gpu sync
                        # However, the performance impact is negligible as we
                        # benchmarked on H100.
                        output.sampled_token_ids[
                            i, :] = model_input.input_tokens[bonus_seq_idx]
                    else:
                        count += 1

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs
vllm/spec_decode/interfaces.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Set
,
Union
import
torch
from
vllm.sequence
import
ExecuteModelRequest
,
PromptLogprobs
from
vllm.worker.worker_base
import
WorkerBase
@dataclass
class SpeculativeProposals:
    """Container for the tokens proposed by a speculative proposer.

    Alongside the proposed tokens it records how many speculative tokens
    each sequence actually has.
    """

    # Token ids proposed by the proposer.
    proposal_token_ids: torch.Tensor

    # Proposer-assigned probabilities for the proposed tokens.
    proposal_probs: torch.Tensor

    # Valid proposal length per sequence; a length of zero is allowed.
    proposal_lens: torch.Tensor

    # Set when the proposer has no proposals available.
    no_proposals: bool = False

    def __repr__(self):
        # Only the shape of the probabilities is rendered; the token ids and
        # lengths are rendered in full.
        rendered_fields = (
            f"proposal_token_ids={self.proposal_token_ids}",
            f"proposal_probs={self.proposal_probs.shape}",
            f"proposal_lens={self.proposal_lens}",
        )
        return "SpeculativeProposals(" + ", ".join(rendered_fields) + ")"
@dataclass
class SpeculativeScores:
    """Container for the scoring model's verdict on speculative tokens."""

    # Scoring-model probabilities for the speculative tokens.
    probs: torch.Tensor

    # Scoring-model log-probabilities for the speculative tokens. They can be
    # turned into the Logprob objects handed back to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Serves both speculative bonus
    # tokens and ordinary (non-speculative) decoding.
    token_ids: torch.Tensor

    # Last hidden states from the scoring model, when provided.
    hidden_states: Optional[torch.Tensor] = None

    # Per-request prompt-token logprobs, which the scoring model may return
    # when chunked prefill is enabled.
    prompt_logprobs: Optional[List[PromptLogprobs]] = None

    def __repr__(self):
        # Only tensor shapes are rendered, never tensor contents.
        template = "SpeculativeScores(probs={}, token_ids={})"
        return template.format(self.probs.shape, self.token_ids.shape)
class SpeculativeProposer(ABC):
    """Abstract interface for components that produce speculative proposals."""

    @abstractmethod
    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # When populated, holds every sequence ID that received a bonus
        # token during its previous forward pass.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Return speculative proposals for the given request.

        Concrete subclasses must override this method; the base version
        always raises.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
class SpeculativeScorer(ABC):
    """Abstract interface for components that score speculative proposals."""

    def __init__(self, scorer_worker: WorkerBase,
                 device: Union[torch.device, str], vocab_size: int):
        """Store the scorer worker, device and vocabulary size.

        Args:
            scorer_worker: worker kept on the instance as ``_scorer_worker``.
            device: a ``torch.device`` or a string. A ``torch.device`` is
                reduced to its ``type`` attribute before being stored.
            vocab_size: vocabulary size kept as ``_vocab_size``.
        """
        self._scorer_worker = scorer_worker
        # A torch.device is stored as its type string; anything else is
        # stored as passed.
        self._device = (device.type
                        if isinstance(device, torch.device) else device)
        self._vocab_size = vocab_size

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score ``proposals`` for the given request.

        Concrete subclasses must override this method; the base version
        always raises.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
vllm/spec_decode/medusa_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
DelegateWorkerBase
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
    """Proposer worker for Medusa speculative decoding."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DelegateWorkerBase.__init__``."""
        DelegateWorkerBase.__init__(self, *args, **kwargs)

        # Declared only; the proposer itself is created in init_device().
        self._proposer: Top1Proposer

    def init_device(self):
        """Initialise the wrapped worker's device, then build the proposer."""
        self.worker.init_device()

        # The proposer receives a weak proxy of this worker rather than a
        # strong reference.
        worker_proxy = weakref.proxy(self)
        self._proposer = Top1Proposer(
            worker_proxy,  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        """No-op for this worker."""
        pass

    def set_should_modify_greedy_probs_inplace(self):
        """No-op for this worker."""
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Not used by this worker.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer, together with a
        flag telling the later sampler_output_to_torch logic whether the
        tensors in the sampler output need to be transposed. For the Medusa
        worker that flag is always False.
        """
        self._raise_if_unsupported(execute_model_req)

        metadata_list = execute_model_req.seq_group_metadata_list
        seq_lens, query_lens = self._prepare_input_tensors(metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            metadata_list,
            seq_lens,
            query_lens,
            self.device,
            self.model_runner.pin_memory,
            generators,
        )

        # Proposals are generated from the hidden states carried on the
        # request.
        proposals = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return proposals, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence lengths and query lengths for the batch.

        Returns:
            ``(seq_lens, query_lens)``, both empty when the metadata list is
            ``None`` or empty. A decode sequence contributes its full length
            and a query length of 1. A prompt sequence contributes its length
            capped at the computed-token count plus the group's chunk size,
            and that capped length minus the computed-token count.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for group in seq_group_metadata_list:
            group_is_prompt = group.is_prompt

            for data in group.seq_data.values():
                total_len = data.get_len()

                if not group_is_prompt:
                    # Decode step: whole sequence, single query token.
                    seq_lens.append(total_len)
                    query_lens.append(1)
                    continue

                # Prompt step: cap at the end of the current chunk.
                computed = data.get_num_computed_tokens()
                chunk_end = min(total_len, computed + group.token_chunk_size)
                seq_lens.append(chunk_end)
                query_lens.append(chunk_end - computed)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences.

        The number of speculative tokens per sequence is determined by
        max_proposal_len. The work is delegated to the proposer created in
        init_device().
        """
        proposer = self._proposer
        return proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """Reject requests using features MedusaWorker does not implement.

        MedusaWorker does not yet support cache swap operations or beam
        search.

        Raises:
            NotImplementedError: if any of the request's swap-in, swap-out or
                copy block collections is truthy, or if any sequence group
                does not hold exactly one sequence.
        """
        pending_cache_ops = (
            execute_model_req.blocks_to_swap_in,
            execute_model_req.blocks_to_swap_out,
            execute_model_req.blocks_to_copy,
        )
        if any(pending_cache_ops):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        for group in execute_model_req.seq_group_metadata_list:
            # A group with other than one sequence is treated as beam search.
            if len(group.seq_data) != 1:
                raise NotImplementedError(
                    "MedusaWorker does not support beam search.")
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment