Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1bd3ae33
"vscode:/vscode.git/clone" did not exist on "73ff872db0d4e3f5e133d5d2a5307248619d93a6"
Commit
1bd3ae33
authored
Oct 11, 2025
by
zhuwenwen
Browse files
skip silu_mul_fp8_quant_deep_gemm_cuda and remove zero_overhead
parent
9bf1b213
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
31 additions
and
3272 deletions
+31
-3272
vllm/config/model.py
vllm/config/model.py
+4
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+0
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+19
-19
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-17
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-2
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+0
-655
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+0
-171
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+0
-500
vllm/zero_overhead/sequence.py
vllm/zero_overhead/sequence.py
+0
-64
vllm/zero_overhead/spec_decode/batch_expansion.py
vllm/zero_overhead/spec_decode/batch_expansion.py
+0
-141
vllm/zero_overhead/spec_decode/muti_step_worker.py
vllm/zero_overhead/spec_decode/muti_step_worker.py
+0
-137
vllm/zero_overhead/spec_decode/spec_decode_worker.py
vllm/zero_overhead/spec_decode/spec_decode_worker.py
+0
-565
vllm/zero_overhead/spec_decode/top1_proproser.py
vllm/zero_overhead/spec_decode/top1_proproser.py
+0
-84
vllm/zero_overhead/stop_check.py
vllm/zero_overhead/stop_check.py
+0
-77
vllm/zero_overhead/tokenizer.py
vllm/zero_overhead/tokenizer.py
+0
-84
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+0
-71
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+0
-357
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+0
-316
No files found.
vllm/config/model.py
View file @
1bd3ae33
...
@@ -276,6 +276,9 @@ class ModelConfig:
...
@@ -276,6 +276,9 @@ class ModelConfig:
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
None
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
None
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
v0.12.0 or v1.0.0, whichever is sooner."""
v0.12.0 or v1.0.0, whichever is sooner."""
enable_chunked_prefill
:
Optional
[
bool
]
=
None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
# Multimodal config and init vars
# Multimodal config and init vars
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
...
@@ -320,6 +323,7 @@ class ModelConfig:
...
@@ -320,6 +323,7 @@ class ModelConfig:
factors
.
append
(
self
.
rope_scaling
)
factors
.
append
(
self
.
rope_scaling
)
factors
.
append
(
self
.
rope_theta
)
factors
.
append
(
self
.
rope_theta
)
factors
.
append
(
self
.
video_pruning_rate
)
factors
.
append
(
self
.
video_pruning_rate
)
factors
.
append
(
self
.
enable_chunked_prefill
)
# hf_config can control how the model looks!
# hf_config can control how the model looks!
try
:
try
:
...
...
vllm/entrypoints/llm.py
View file @
1bd3ae33
...
@@ -56,7 +56,6 @@ from vllm.v1.engine.llm_engine import LLMEngine
...
@@ -56,7 +56,6 @@ from vllm.v1.engine.llm_engine import LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -300,12 +299,8 @@ class LLM:
...
@@ -300,12 +299,8 @@ class LLM:
log_non_default_args
(
engine_args
)
log_non_default_args
(
engine_args
)
# Create the Engine (autoselects V0 vs V1)
# Create the Engine (autoselects V0 vs V1)
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
self
.
llm_engine
=
ZeroOverheadEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
else
:
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
1bd3ae33
...
@@ -1840,8 +1840,7 @@ class FusedMoE(CustomOp):
...
@@ -1840,8 +1840,7 @@ class FusedMoE(CustomOp):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
)
e_score_correction_bias
=
e_score_correction_bias
)
if
indices_type
is
not
None
:
if
indices_type
is
not
None
:
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
elif
e_score_correction_bias
is
not
None
:
elif
e_score_correction_bias
is
not
None
:
...
...
vllm/v1/engine/core.py
View file @
1bd3ae33
...
@@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, TypeVar, Union
...
@@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, TypeVar, Union
import
msgspec
import
msgspec
from
vllm
import
envs
from
vllm
import
envs
from
vllm.zero_overhead.v1.core
import
engine_core_step
import
zmq
import
zmq
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
...
@@ -277,8 +276,6 @@ class EngineCore:
...
@@ -277,8 +276,6 @@ class EngineCore:
Returns tuple of outputs and a flag indicating whether the model
Returns tuple of outputs and a flag indicating whether the model
was executed.
was executed.
"""
"""
if
envs
.
VLLM_ZERO_OVERHEAD
:
return
engine_core_step
(
self
)
# Check for any requests remaining in the scheduler - unfinished,
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
# or finished and not yet removed from the batch.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1bd3ae33
...
@@ -1458,25 +1458,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1458,25 +1458,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
target_logits_indices
+=
arange
if
envs
.
VLLM_ZERO_OVERHEAD
:
#
if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
pin_memory
().
to
(
#
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self
.
device
,
non_blocking
=
True
)
#
self.device, non_blocking=True)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
pin_memory
().
to
(
self
.
device
,
#
logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking
=
True
)
#
non_blocking=True)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
pin_memory
().
to
(
#
target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self
.
device
,
non_blocking
=
True
)
#
self.device, non_blocking=True)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
pin_memory
().
to
(
#
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self
.
device
,
non_blocking
=
True
)
#
self.device, non_blocking=True)
else
:
#
else:
# TODO: Optimize the CPU -> GPU copy.
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
# Compute the draft token ids.
# Compute the draft token ids.
...
...
vllm/v1/worker/gpu_worker.py
View file @
1bd3ae33
...
@@ -34,8 +34,6 @@ from vllm.v1.utils import report_usage_stats
...
@@ -34,8 +34,6 @@ from vllm.v1.utils import report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -200,13 +198,8 @@ class Worker(WorkerBase):
...
@@ -200,13 +198,8 @@ class Worker(WorkerBase):
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Construct the model runner
# Construct the model runner
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
logger
.
info
(
'use zero overhead model_runner'
)
self
.
vllm_config
,
self
.
device
)
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
# If usage stat is enabled, collect relevant info.
...
@@ -451,14 +444,8 @@ class Worker(WorkerBase):
...
@@ -451,14 +444,8 @@ class Worker(WorkerBase):
all_gather_group
=
get_tp_group
(),
all_gather_group
=
get_tp_group
(),
all_gather_tensors
=
all_gather_tensors
))
all_gather_tensors
=
all_gather_tensors
))
if
envs
.
VLLM_ZERO_OVERHEAD
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
use_stream
=
zero_overhead_stream
(
self
.
device
)
intermediate_tensors
)
with
torch
.
cuda
.
stream
(
use_stream
):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
else
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
isinstance
(
output
,
(
ModelRunnerOutput
,
AsyncModelRunnerOutput
)):
if
isinstance
(
output
,
(
ModelRunnerOutput
,
AsyncModelRunnerOutput
)):
return
output
return
output
...
...
vllm/worker/worker_base.py
View file @
1bd3ae33
...
@@ -52,8 +52,7 @@ class WorkerBase:
...
@@ -52,8 +52,7 @@ class WorkerBase:
different hardware. Also abstracts control plane communication, e.g., to
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
communicate request metadata to other workers.
"""
"""
# TODO
model_input
:
Optional
[
ModelRunnerInputBase
]
=
None
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
__init__
(
def
__init__
(
...
...
vllm/zero_overhead/llm_engine.py
deleted
100644 → 0
View file @
9bf1b213
from
functools
import
partial
import
os
import
queue
import
threading
import
traceback
from
typing
import
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
zlib
import
ZLIB_VERSION
import
torch
from
vllm
import
envs
from
vllm.config
import
DecodingConfig
,
ObservabilityConfig
,
VllmConfig
from
vllm.core.scheduler
import
ScheduledSequenceGroup
from
vllm.engine.llm_engine
import
_LOCAL_LOGGING_INTERVAL_SEC
,
LLMEngine
,
SchedulerContext
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.output_processor.interfaces
import
SequenceGroupOutputProcessor
from
vllm.logger
import
init_logger
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
ProcessorInputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.registry
import
InputRegistry
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
from
vllm.tracing
import
init_tracer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
,
Counter
from
vllm.zero_overhead.sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.tokenizer
import
ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
,
zero_overhead_stream
logger
=
init_logger
(
__name__
)
class
ZeroOverheadEngine
(
LLMEngine
):
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: Type[ExecutorBase],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
) -> None:
    """Build the engine state and, unless disabled, start the worker thread.

    Args:
        vllm_config: Aggregate configuration; its sub-configs are copied
            onto ``self`` as individual attributes.
        executor_class: Instantiated with ``vllm_config=vllm_config`` and
            stored as ``self.model_executor``.
        log_stats: When true, stat loggers are installed.
        usage_context: Forwarded to ``usage_message.report_usage`` when
            usage stats are enabled.
        stat_loggers: Used as-is when given and ``log_stats`` is true;
            otherwise logging and prometheus loggers are created.
        mm_registry: Passed to ``InputPreprocessor``.
        use_cached_outputs: Stored on ``self`` and included in the start-up
            log line.

    Raises:
        ValueError: If ``envs.VLLM_USE_V1`` is set.
    """
    if envs.VLLM_USE_V1:
        raise ValueError(
            "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
            "This should not happen. As a workaround, try using "
            "LLMEngine.from_vllm_config(...) or explicitly set "
            "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.parallel_config = vllm_config.parallel_config
    self.scheduler_config = vllm_config.scheduler_config
    self.device_config = vllm_config.device_config
    self.speculative_config = vllm_config.speculative_config  # noqa
    self.load_config = vllm_config.load_config
    self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
    )
    self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
    self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
    )

    logger.info(
        "Initializing a V0 LLM engine (v%s) with config: %s, "
        "use_cached_outputs=%s, ",
        VLLM_VERSION,
        vllm_config,
        use_cached_outputs,
    )

    self.log_stats = log_stats
    self.use_cached_outputs = use_cached_outputs
    # Set early so finish_thread()/__del__ can rely on the flag existing.
    self.thread_running = False

    if not self.model_config.skip_tokenizer_init:
        self.tokenizer = self._init_tokenizer()
        self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
        tokenizer_group = self.get_tokenizer_group()
    else:
        self.tokenizer = None
        self.detokenizer = None
        tokenizer_group = None

    # Ensure that the function doesn't contain a reference to self,
    # to avoid engine GC issues
    def get_tokenizer_for_seq(
            sequence: ZeroOverheadSequence) -> AnyTokenizer:
        assert tokenizer_group, ("tokenizer_group cannot be None, "
                                 "make sure skip_tokenizer_init is False")
        return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    self.seq_counter = Counter()
    self.generation_config_fields = (
        self.model_config.try_get_generation_config())

    self.input_preprocessor = InputPreprocessor(self.model_config,
                                                self.tokenizer,
                                                mm_registry)

    self.model_executor = executor_class(vllm_config=vllm_config, )

    if self.model_config.runner_type != "pooling":
        self._initialize_kv_caches()

    # If usage stat is enabled, collect relevant info.
    if is_usage_stats_enabled():
        from vllm.model_executor.model_loader import (
            get_architecture_class_name)
        usage_message.report_usage(
            get_architecture_class_name(self.model_config),
            usage_context,
            extra_kvs={
                # Common configuration
                "dtype":
                str(self.model_config.dtype),
                "tensor_parallel_size":
                self.parallel_config.tensor_parallel_size,
                "block_size":
                self.cache_config.block_size,
                "gpu_memory_utilization":
                self.cache_config.gpu_memory_utilization,

                # Quantization
                "quantization":
                self.model_config.quantization,
                "kv_cache_dtype":
                str(self.cache_config.cache_dtype),

                # Feature flags
                "enable_lora":
                bool(self.lora_config),
                "enable_prompt_adapter":
                bool(self.prompt_adapter_config),
                "enable_prefix_caching":
                self.cache_config.enable_prefix_caching,
                "enforce_eager":
                self.model_config.enforce_eager,
                "disable_custom_all_reduce":
                self.parallel_config.disable_custom_all_reduce,
            })

    self.cached_scheduler_outputs = [
        SchedulerOutputState()
        for _ in range(self.parallel_config.pipeline_parallel_size)
    ]

    self.scheduler_contexts = [
        SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                         multi_step_stream_outputs)
        for _ in range(self.parallel_config.pipeline_parallel_size)
    ]

    if self.model_config.use_async_output_proc:
        process_model_outputs = weak_bind(self._process_model_outputs)

        self.async_callbacks = [
            partial(process_model_outputs,
                    ctx=self.scheduler_contexts[v_id])
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]
    else:
        self.async_callbacks = []

    # Currently used by AsyncLLMEngine to ensure quick append
    # of request outputs to asyncio queues
    self.process_request_outputs_callback: Optional[Callable] = None

    # Create the scheduler.
    # NOTE: the cache_config here have been updated with the numbers of
    # GPU and CPU blocks, which are profiled in the distributed executor.
    if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
        Scheduler = resolve_obj_by_qualname(
            self.vllm_config.scheduler_config.scheduler_cls)
    else:
        Scheduler = self.vllm_config.scheduler_config.scheduler_cls
    self.scheduler = [
        Scheduler(
            self.scheduler_config, self.cache_config, self.lora_config,
            self.parallel_config.pipeline_parallel_size,
            self.async_callbacks[v_id]
            if self.model_config.use_async_output_proc else None)
        for v_id in range(self.parallel_config.pipeline_parallel_size)
    ]

    # Metric Logging.
    if self.log_stats:
        if stat_loggers is not None:
            self.stat_loggers = stat_loggers
        else:
            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from vllm.engine.metrics import (LoggingStatLogger,
                                             PrometheusStatLogger)

            self.stat_loggers = {
                "logging":
                LoggingStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    vllm_config=vllm_config),
                "prometheus":
                PrometheusStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    labels=dict(
                        model_name=self.model_config.served_model_name),
                    vllm_config=vllm_config),
            }
            self.stat_loggers["prometheus"].info("cache_config",
                                                 self.cache_config)

    self.tracer = None
    if self.observability_config.otlp_traces_endpoint:
        self.tracer = init_tracer(
            "vllm.llm_engine",
            self.observability_config.otlp_traces_endpoint)

    # Create sequence output processor, e.g. for beam search or
    # speculative decoding.
    self.output_processor = (
        SequenceGroupOutputProcessor.create_output_processor(
            self.scheduler_config,
            self.detokenizer,
            self.scheduler,
            self.seq_counter,
            get_tokenizer_for_seq,
            stop_checker=ZeroOverheadStopChecker(
                self.scheduler_config.max_model_len,
                get_tokenizer_for_seq,
            ),
        ))

    self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

    self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

    # Flag to set when an input fails to process and the engine should run
    # the next step without re-scheduling.
    self._skip_scheduling_next_step = False

    # State shared between the stepping code and the worker loop.
    self.async_d2h = None  # pending CPU copy of the last sampled tokens
    self.last_record = None  # record of the most recently launched step
    self.async_event = torch.cuda.Event(enable_timing=False)
    self.q_recorder = queue.Queue()
    self.use_stream = zero_overhead_stream(
        self.model_executor.device_config.device)
    if not is_zero_no_thread():
        # Create the semaphore before anything that may observe the flag.
        self.sem_m2s = threading.Semaphore(0)  # main to scheduler thread
        # daemon=True: the target is a bound method, so the thread keeps
        # this engine alive and __del__ cannot stop it; a non-daemon thread
        # parked on the semaphore would block interpreter exit.
        self.zero_thread = threading.Thread(
            target=self.thread_zero_overhead, daemon=True)
        self.thread_running = True
        self.zero_thread.start()
    # NOTE(review): original nesting of this call is not recoverable from
    # the source view; assumed to run regardless of threading mode.
    profile.StartTracer()
def __del__(self):
    """Stop the background thread, then run the parent finalizer.

    ``__init__`` can raise before ``thread_running`` is assigned, and
    Python still calls ``__del__`` on such an instance, so the flag is read
    defensively rather than assumed to exist.
    """
    if getattr(self, "thread_running", False):
        self.finish_thread()
    return super().__del__()
def finish_thread(self):
    """Ask the background scheduling thread to leave its loop.

    Clears ``thread_running`` and releases ``sem_m2s`` once, so a worker
    blocked in ``sem_m2s.acquire()`` wakes up and sees the cleared flag.
    Does nothing when no thread is marked as running, including on an
    instance whose ``__init__`` never got as far as setting the flag.
    """
    # getattr: this is reached from __del__, which also runs on instances
    # whose construction failed before ``thread_running`` was assigned.
    if not getattr(self, "thread_running", False):
        return
    self.thread_running = False
    self.sem_m2s.release()
def thread_zero_overhead(self):
    """Body of the background thread: schedule and launch one step per wake-up.

    Each iteration waits on ``sem_m2s``, schedules, hands the previous
    step's record (or ``None`` when there is none) to ``q_recorder``, then
    executes the newly scheduled batch and stores its record in
    ``last_record``. The loop exits when it wakes up and finds
    ``thread_running`` cleared.
    """
    logger.info('zero overhead thread start!')
    last_sampler = get_last_sampler()
    last_sampler.seq_ids.clear()
    try:
        with torch.cuda.stream(self.use_stream):
            while True:
                self.sem_m2s.acquire()
                if not self.thread_running:
                    logger.debug("Stopping remote worker execution loop.")
                    self.model_executor.stop_remote_worker_execution_loop()
                    break
                virtual_engine = 0

                # Clear outputs for each new scheduler iteration
                # Schedule iteration
                (seq_group_metadata_list, scheduler_outputs,
                 _allow_async_output_proc
                 ) = self.scheduler[virtual_engine].schedule()

                if self.last_record is not None:
                    # Start copying the previous step's tokens to the CPU
                    # before publishing its record to the main thread.
                    last_sampler = self.last_record[1]
                    spec_step = get_spec_step()
                    if spec_step == SpecStepKind.KIND_DEFAULT:
                        if last_sampler.sampled_token_ids_tensor is not None:
                            self.async_d2h = \
                                last_sampler.sampled_token_ids_tensor.to(
                                    'cpu', non_blocking=True)
                        else:
                            self.async_d2h = None
                    elif spec_step == SpecStepKind.SCORE_DECODE:
                        self.async_d2h = last_sampler.to('cpu',
                                                         non_blocking=True)
                    self.async_event.record()
                    self.q_recorder.put(self.last_record)
                else:
                    self.q_recorder.put(None)

                if len(seq_group_metadata_list) == 0:
                    self.last_record = None
                    continue

                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

                assert seq_group_metadata_list is not None
                assert scheduler_outputs is not None

                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)

                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids)

                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)

                for output in outputs:
                    self._advance_to_next_step(
                        output, seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)
                # Shallow copy: a new list object holding the same items.
                scheduler_outputs.scheduled_seq_groups = list(
                    scheduler_outputs.scheduled_seq_groups)
                last_sampler = None
                spec_step = get_spec_step()
                if spec_step == SpecStepKind.KIND_DEFAULT:
                    last_sampler = get_last_sampler()
                elif spec_step == SpecStepKind.SCORE_DECODE:
                    last_sampler, _ = get_accepted_token_ids()
                self.last_record = [
                    outputs, last_sampler, seq_group_metadata_list,
                    scheduler_outputs, spec_step
                ]
    except Exception as e:
        # Report through the module logger (traceback included) instead of
        # printing to stdout.
        # NOTE(review): after a failure here the thread is gone while
        # ``thread_running`` stays set, so a later wait on ``q_recorder``
        # looks like it would block forever - confirm and decide on a
        # recovery strategy.
        logger.exception("thread_zero_overhead error : %s", e)
def zero_overhead_step(
        self) -> Optional[List[Union[RequestOutput, PoolingRequestOutput]]]:
    """Run one engine step with the help of the background thread.

    Wakes the worker via ``sem_m2s``, takes the record it publishes on
    ``q_recorder``, patches in the sampled tokens once the recorded CUDA
    event has completed, and processes the outputs.

    Returns:
        ``ctx.request_outputs`` for the processed step, or ``None`` when the
        worker published no record (the first, priming call).
    """
    if not self.thread_running:
        # The previous worker was told to stop; wait for it, then start anew.
        self.zero_thread.join()
        self.thread_running = True
        # daemon=True so a worker parked on the semaphore cannot keep the
        # interpreter from exiting.
        self.zero_thread = threading.Thread(
            target=self.thread_zero_overhead, daemon=True)
        self.zero_thread.start()
    self.sem_m2s.release()
    recode_output = self.q_recorder.get()
    if recode_output is None:
        # None is for the first step
        return None
    virtual_engine = 0
    ctx = self.scheduler_contexts[virtual_engine]
    ctx.request_outputs.clear()
    (outputs, last_sampler, seq_group_metadata_list, scheduler_outputs,
     spec_step) = recode_output
    ctx.seq_group_metadata_list = seq_group_metadata_list
    ctx.scheduler_outputs = scheduler_outputs

    if spec_step == SpecStepKind.KIND_DEFAULT:
        self.async_event.synchronize()
        if self.async_d2h is not None:
            self._fix_last_step(outputs, last_sampler,
                                seq_group_metadata_list,
                                scheduler_outputs.scheduled_seq_groups)
    elif spec_step == SpecStepKind.SCORE_DECODE:
        self.async_event.synchronize()
        self._fix_spec_decode_steps(outputs, seq_group_metadata_list,
                                    scheduler_outputs.scheduled_seq_groups)

    # is_first_step_output is True only when the num_steps of all
    # the sequences are 1. When the num_steps > 1,
    # multi_step_model_runner does the first-step output append.
    is_first_step_output: bool = False if not seq_group_metadata_list \
        else seq_group_metadata_list[0].state.num_steps == 1

    # Add results to the output_queue
    ctx.append_output(outputs=outputs,
                      seq_group_metadata_list=seq_group_metadata_list,
                      scheduler_outputs=scheduler_outputs,
                      is_async=True,
                      is_last_step=True,
                      is_first_step_output=is_first_step_output)

    # Check if need to run the usual non-async path
    #if not allow_async_output_proc:
    self._process_model_outputs(ctx=ctx)

    #profile.ProfRangeAutoPush('has_unfinish')
    if not self.has_unfinished_requests():
        # Drain async postprocessor (if exists)
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        assert len(ctx.output_queue) == 0

        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
        # torch.distributed ops which may otherwise timeout, and unblocks
        # the RPC thread in the workers so that they can process any other
        # queued control plane messages, such as add/remove lora adapters.
        # logger.debug("Stopping remote worker execution loop.")
        # self.model_executor.stop_remote_worker_execution_loop()
        self.finish_thread()

    return ctx.request_outputs
def
_fix_last_step
(
self
,
output
:
List
[
SamplerOutput
],
last_sampler
:
SampleRecorder
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_ids
=
last_sampler
.
seq_ids
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
continue
if
seq_group_metadata
.
do_sample
:
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
:
ZeroOverheadSequence
=
seq_group
.
seqs
[
0
]
for
token_id
,
seq_id
in
zip
(
sample_out_list
,
sample_out_ids
):
if
seq
.
seq_id
==
seq_id
:
if
type
(
token_id
)
is
list
:
sample
.
output_token
=
token_id
[
0
]
else
:
sample
.
output_token
=
token_id
seq
.
fix_last_token_id
(
sample
.
output_token
)
break
def
_fix_spec_decode_steps
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
]):
sample_out_list
=
self
.
async_d2h
.
tolist
()
group_idx
=
0
for
seq_group_metadata
,
accept_token_ids
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
sample_out_list
,
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
group_idx
+=
1
continue
if
seq_group_metadata
.
do_sample
:
assert
len
(
seq_group
.
seqs
)
==
1
seq
:
ZeroOverheadSequence
=
seq_group
.
seqs
[
0
]
remove_count
=
0
for
token_id
in
accept_token_ids
:
if
token_id
==
-
1
:
remove_count
+=
1
else
:
seq
.
fix_last_token_id
(
token_id
)
seq
.
remove_last_place_holder
(
remove_count
)
group_idx
+=
1
def no_thread_step(self):
    """Run one engine step without the background thread.

    Schedules, publishes the previous step's record (or ``None``) on
    ``q_recorder``, launches the newly scheduled batch, then takes the
    published record back and processes it.

    Returns:
        ``ctx.request_outputs`` for the processed step, or ``None`` when
        there was no previous record to process.
    """
    virtual_engine = 0

    # Clear outputs for each new scheduler iteration
    # Schedule iteration
    (seq_group_metadata_list, scheduler_outputs,
     _allow_async_output_proc) = self.scheduler[virtual_engine].schedule()

    if self.last_record is not None:
        last_sampler = self.last_record[1]
        # Same None handling as the threaded loop: without a sampled-token
        # tensor there is nothing to copy, and a stale copy must not be
        # reused.
        sampled_tensor = last_sampler.sampled_token_ids_tensor
        if sampled_tensor is not None:
            self.async_d2h = sampled_tensor.to('cpu', non_blocking=True)
        else:
            self.async_d2h = None
        self.async_event.record()
        self.q_recorder.put(self.last_record)
    else:
        self.q_recorder.put(None)

    if len(seq_group_metadata_list) == 0:
        self.last_record = None
    else:
        finished_requests_ids = self.scheduler[
            virtual_engine].get_and_reset_finished_requests_ids()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        last_sampled_token_ids = \
            self._get_last_sampled_token_ids(virtual_engine)

        execute_model_req = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
            num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
            running_queue_size=scheduler_outputs.running_queue_size,
            finished_requests_ids=finished_requests_ids,
            # We use ExecuteModelRequest to pass the last sampled_token_ids
            # to each of the non-last PP stages for in-place prepare_input.
            last_sampled_token_ids=last_sampled_token_ids)

        outputs = self.model_executor.execute_model(
            execute_model_req=execute_model_req)

        if len(outputs) == 1:
            self._advance_to_next_step(
                outputs[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)
        # Shallow copy: a new list object holding the same items.
        # NOTE(review): nesting relative to the check above is inferred
        # from the threaded loop - confirm.
        scheduler_outputs.scheduled_seq_groups = list(
            scheduler_outputs.scheduled_seq_groups)
        last_sampler = get_last_sampler()
        self.last_record = [
            outputs, last_sampler, seq_group_metadata_list,
            scheduler_outputs
        ]

    recode_output = self.q_recorder.get()
    if recode_output is None:
        # None is for the first step
        return None
    virtual_engine = 0
    ctx = self.scheduler_contexts[virtual_engine]
    ctx.request_outputs.clear()
    (outputs, last_sampler, seq_group_metadata_list,
     scheduler_outputs) = recode_output
    ctx.seq_group_metadata_list = seq_group_metadata_list
    ctx.scheduler_outputs = scheduler_outputs
    self.async_event.synchronize()
    if self.async_d2h is not None:
        self._fix_last_step(outputs, last_sampler, seq_group_metadata_list,
                            scheduler_outputs.scheduled_seq_groups)

    # is_first_step_output is True only when the num_steps of all
    # the sequences are 1. When the num_steps > 1,
    # multi_step_model_runner does the first-step output append.
    is_first_step_output: bool = False if not seq_group_metadata_list \
        else seq_group_metadata_list[0].state.num_steps == 1

    # Add results to the output_queue
    ctx.append_output(outputs=outputs,
                      seq_group_metadata_list=seq_group_metadata_list,
                      scheduler_outputs=scheduler_outputs,
                      is_async=True,
                      is_last_step=True,
                      is_first_step_output=is_first_step_output)

    # Check if need to run the usual non-async path
    #if not allow_async_output_proc:
    self._process_model_outputs(ctx=ctx)

    #profile.ProfRangeAutoPush('has_unfinish')
    if not self.has_unfinished_requests():
        # Drain async postprocessor (if exists)
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        assert len(ctx.output_queue) == 0

        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
        # torch.distributed ops which may otherwise timeout, and unblocks
        # the RPC thread in the workers so that they can process any other
        # queued control plane messages, such as add/remove lora adapters.
        logger.debug("Stopping remote worker execution loop.")
        self.model_executor.stop_remote_worker_execution_loop()

    return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
    """Run one engine iteration under ``self.use_stream``.

    ``no_thread_step`` is used when ``is_zero_no_thread()`` is true and
    ``zero_overhead_step`` otherwise. If the chosen step function returns
    ``None`` (which it does for the very first launch), it is invoked a
    second time and that second result is returned unchanged.

    Returns:
        Whatever the selected step function returned on its last call.
    """
    with torch.cuda.stream(self.use_stream):
        run_once = (self.no_thread_step
                    if is_zero_no_thread() else self.zero_overhead_step)
        result = run_once()
        if result is None:
            # The first step needs to be launched twice.
            result = run_once()
        return result
def _add_processed_request(
    self,
    request_id: str,
    processed_inputs: ProcessorInputs,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    trace_headers: Optional[Mapping[str, str]] = None,
    priority: int = 0,
) -> Optional[SequenceGroup]:
    """Add a processed request to the engine's request pool.

    Requests with ``SamplingParams.n > 1`` are handed to
    ``ParallelSampleSequenceGroup.add_request`` and ``None`` is returned.
    Otherwise the inputs are validated, wrapped in ``ZeroOverheadSequence``
    objects, turned into a sequence group and queued on the scheduler that
    currently reports the fewest unfinished sequence groups.

    Returns:
        The created sequence group, or ``None`` for the parallel-sampling
        path.

    Raises:
        ValueError: If ``params`` is neither ``SamplingParams`` nor
            ``PoolingParams``.
    """
    wants_sampling = isinstance(params, SamplingParams)
    if wants_sampling and params.n > 1:
        ParallelSampleSequenceGroup.add_request(
            request_id,
            self,
            params,
            processed_inputs=processed_inputs,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
        return None

    self._validate_model_inputs(processed_inputs, lora_request)

    # Create the sequences.
    block_size = self.cache_config.block_size
    seq_id = next(self.seq_counter)
    eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
    encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

    def new_sequence(inputs):
        # Decoder and encoder sequences share the id and all other settings.
        return ZeroOverheadSequence(seq_id, inputs, block_size, eos_token_id,
                                    lora_request)

    seq = new_sequence(decoder_inputs)
    if encoder_inputs is None:
        encoder_seq = None
    else:
        encoder_seq = new_sequence(encoder_inputs)

    # Create a SequenceGroup based on SamplingParams or PoolingParams.
    if wants_sampling:
        seq_group = self._create_sequence_group_with_sampling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            encoder_seq=encoder_seq,
            priority=priority)
    elif isinstance(params, PoolingParams):
        seq_group = self._create_sequence_group_with_pooling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            encoder_seq=encoder_seq,
            priority=priority)
    else:
        raise ValueError(
            "Either SamplingParams or PoolingParams must be provided.")

    # Add the sequence group to the scheduler with least unfinished seqs.
    # ``min`` keeps the first scheduler on ties, like ``index(min(costs))``.
    least_loaded = min(
        self.scheduler,
        key=lambda candidate: candidate.get_num_unfinished_seq_groups())
    least_loaded.add_seq_group(seq_group)

    return seq_group
\ No newline at end of file
vllm/zero_overhead/model_runner.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
import
itertools
from
typing
import
List
,
Optional
,
Set
from
vllm.lora.layers
import
LoRAMapping
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.sampler
import
get_last_sampler
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_proposal_token_ids
,
get_spec_last_step
,
get_spec_step
import
triton
import
triton.language
as
tl
# Triton kernel that rewrites, in place, the per-request pairs of entries in
# the input tensors after a speculative "score decode" step.
#
# For every even position in `chidren_req_ids` it looks up the matching row of
# `accepted_token_ids`, counts the accepted tokens up to the first -1, and
# then either copies the last two accepted tokens into `input_tokens` or
# (partial acceptance) pads the first slot and fixes positions, context
# lengths, slot mapping and sequence lengths. Finally `seq_start_loc` is
# rebuilt as a running sum of `seq_lens`.
#
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation — confirm it against the original file.
# NOTE(review): the visible caller passes `attn_metadata.seq_lens_tensor` for
# `seq_lens_meta` and `attn_metadata.seq_lens` for `seq_lens_tensor`; the
# names look swapped relative to the arguments — confirm.
@triton.jit
def _update_input_tokens(accepted_req_ids, accepted_req_ids_len,
                         accepted_token_ids, accepted_token_len,
                         chidren_req_ids, chidren_req_ids_len, input_tokens,
                         input_tokens_len, input_positions, seq_lens,
                         seq_lens_meta, seq_lens_tensor, slot_mapping,
                         seq_start_loc, context_lens_tensor,
                         ):
    chidren_req_ids_ = tl.load(chidren_req_ids +
                               tl.arange(0, chidren_req_ids_len))
    # NOTE(review): this load is sized by chidren_req_ids_len, not
    # accepted_req_ids_len — confirm that is intended.
    accepted_req_ids_ = tl.load(accepted_req_ids +
                                tl.arange(0, chidren_req_ids_len))
    # Only every second child id is inspected (index 2 * seq_id_idx).
    for seq_id_idx in range(chidren_req_ids_len / 2):
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                # NOTE(review): the second argument of the outer tl.arange is
                # itself a tl.arange — verify this is valid Triton.
                accepted_token_ids_ = tl.load(
                    accepted_token_ids +
                    tl.arange(i * accepted_token_len,
                              tl.arange(0, accepted_token_len)))
                # Count accepted tokens; -1 marks the first rejected slot.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Everything accepted: feed the last two accepted tokens.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2),
                             accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: slot 0 gets token 0, slot 1 gets the
                    # last accepted token.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1,
                             accepted_token_ids_[accepted_token_counter - 1])
                    input_pos = tl.load(input_positions + seq_id_idx * 2 +
                                        tl.arange(0, 2))
                    input_pos[0] = 0
                    # NOTE(review): subtracts accepted_req_ids_len (number of
                    # requests) rather than accepted_token_len — confirm.
                    input_pos[1] = input_pos[1] - (accepted_req_ids_len -
                                                   accepted_token_counter)
                    tl.store(
                        input_positions + seq_id_idx * 2 + tl.arange(0, 2),
                        input_pos)
                    tl.store(
                        context_lens_tensor + seq_id_idx * 2 +
                        tl.arange(0, 2), input_pos)
                    # Slot 0 is written as -1 in the slot mapping.
                    input_pos[0] = -1
                    tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2),
                             input_pos)
                    # Sequence lengths: 1 for slot 0, position + 1 for slot 1.
                    input_pos[0] = 1
                    input_pos[1] = input_pos[1] + 1
                    tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2),
                             input_pos)
                    tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2),
                             input_pos)
                    tl.store(
                        seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2),
                        input_pos)
    # Rebuild seq_start_loc as the running sum of the sequence lengths.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    # NOTE(review): triton.language documents `zeros_like`; confirm that
    # `tl.zero_like` resolves.
    seq_start_loc_ = tl.zero_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1),
             seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Input builder that patches the built inputs with device-side tokens.

    The builder remembers the sequence ids of the batch in the order they
    were added. After the parent class has built the model input,
    ``build`` overwrites entries of ``input_tokens`` with tokens that are
    taken from the last sampler record or from recorded proposal tokens,
    depending on the current and previous speculative-decoding step kind.
    """

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the batch being built, in insertion order.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Forget the ids of the previous batch, then defer to the parent."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record every sequence id of the group, then defer to the parent."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and patch ``input_tokens`` where needed.

        Returns:
            The model input produced by the parent builder, possibly with
            entries of ``input_tokens`` (and, after a score-decode step, the
            attention metadata tensors) rewritten in place.
        """
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()

        def to_device_long(values):
            # Host list -> long tensor on the runner's device.
            return async_tensor_h2d(values, torch.long, self.runner.device,
                                    self.runner.pin_memory)

        # NOTE(review): all step-kind branches are kept under this guard;
        # the nesting was reconstructed from a flattened listing — confirm.
        if last_sampler is not None:
            if spec_step == SpecStepKind.KIND_DEFAULT:
                # Pair each current sequence with its row in the last sample.
                update_indices = []
                select_indices = []
                query_idx = 0
                for batch_pos, seq_id in enumerate(self.req_ids):
                    previous_row = next(
                        (row for row, previous_id in enumerate(
                            last_sampler.seq_ids) if seq_id == previous_id),
                        None)
                    if previous_row is not None:
                        select_indices.append(previous_row)
                        update_indices.append(query_idx)
                    query_idx += model_input.query_lens[batch_pos]
                if (select_indices and
                        last_sampler.sampled_token_ids_tensor is not None):
                    select_indices = to_device_long(select_indices)
                    update_indices = to_device_long(update_indices)
                    model_input.input_tokens[update_indices] = \
                        last_sampler.sampled_token_ids_tensor[
                            select_indices, 0]

            if spec_step == SpecStepKind.OTHER_PROPOSAL:
                # After another proposal step the last sampled token ids are
                # copied to the input tokens directly. The first-proposal
                # case currently does the same.
                # TODO: adjust input tokens number to 1 per request
                # (first-proposal case).
                if last_step in (SpecStepKind.OTHER_PROPOSAL,
                                 SpecStepKind.FIRST_PROPOSAL):
                    update_indices = to_device_long(
                        list(range(len(self.req_ids))))
                    model_input.input_tokens[update_indices] = \
                        last_sampler.sampled_token_ids_tensor[
                            update_indices, 0]

            if spec_step == SpecStepKind.SCORE_DECODE:
                proposal_token_ids = get_proposal_token_ids()
                shape = proposal_token_ids.shape
                batch_size = shape[0]
                proposal_len = shape[1]
                # Each request owns proposal_len + 1 slots; slot 0 is left
                # untouched and slots 1.. receive the proposal tokens.
                update_indices = to_device_long([
                    row * (proposal_len + 1) + col + 1
                    for row in range(batch_size)
                    for col in range(proposal_len)
                ])
                model_input.input_tokens[update_indices] = \
                    proposal_token_ids.view(-1)

            if spec_step == SpecStepKind.FIRST_PROPOSAL:
                if last_step == SpecStepKind.PREFILL:
                    # TODO: when last step is prefill, just update the input
                    # ids for the last sequence id only.
                    pass
                if last_step == SpecStepKind.SCORE_DECODE:
                    # TODO: when last step is score decode, fix input ids,
                    # seq_lens and input_positions using accepted token ids.
                    accept_token_ids, accept_seq_ids = \
                        get_accepted_token_ids()
                    chidren_req_ids = to_device_long(self.req_ids)
                    grid = [1, 1, 1]
                    attn_metadata = model_input.attn_metadata
                    _update_input_tokens[grid](
                        accept_seq_ids,
                        accept_seq_ids.shape[0],
                        accept_token_ids,
                        accept_token_ids.shape[1],
                        chidren_req_ids,
                        chidren_req_ids.shape[0],
                        model_input.input_tokens,
                        model_input.input_tokens.shape[0],
                        model_input.input_positions,
                        model_input.seq_lens,
                        attn_metadata.seq_lens_tensor,
                        attn_metadata.seq_lens,
                        attn_metadata.slot_mapping,
                        attn_metadata.seq_start_loc,
                        attn_metadata.context_lens_tensor,
                    )
        return model_input
vllm/zero_overhead/sampler.py
deleted
100644 → 0
View file @
9bf1b213
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.model_executor.layers.sampler
import
MaybeDeferredSampleResultType
,
MultinomialSamplesType
,
SampleMetadataType
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
\
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
,
_multinomial
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
# Optional FlashInfer sampling backend. The import is attempted only when the
# VLLM_USE_FLASHINFER_SAMPLER setting is truthy and a "flashinfer" package
# can be located; otherwise the name is bound to None so callers can test
# `flashinfer_top_k_top_p_sampling is None`.
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None
class SampleRecorder:
    """Mutable record of what the most recent sampler pass produced.

    Both attributes start as ``None`` and are filled in by the sampling
    code elsewhere in this module.
    """

    def __init__(self):
        # Sequence ids of the groups that produced at least one output.
        self.seq_ids: Optional[list] = None
        # Device tensor holding the token ids that were sampled.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
# Module-level slot holding the record of the most recent sampler pass.
# It stays None until ZeroOverheadSampler.forward replaces it.
last_sampler = None


def get_last_sampler():
    """Return the module-level ``last_sampler`` (``None`` until it is set)."""
    return last_sampler
class ZeroOverheadSampler(Sampler):
    """Sampler that publishes a fresh ``SampleRecorder`` on every forward.

    The forward pass follows the usual order (penalties, temperature,
    top-k/top-p, min-p, softmax, sampling) and delegates the actual
    sampling and output construction to this module's ``_sample`` and
    ``_build_sampler_output``.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``.

        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        # Start a new record for this pass; helpers below fill it in.
        global last_sampler
        last_sampler = SampleRecorder()

        assert logits is not None
        _, vocab_size = logits.shape

        # Sampling tensors are rebuilt unless reuse was requested. With
        # penalties they are always rebuilt, because they depend on the
        # "output_tokens" of a sequence, which change between decode runs.
        if (not sampling_metadata.reuse_sampling_tensors
                or self._do_penalties):
            self._init_sampling_tensors(logits, sampling_metadata)

        sampling_tensors = self._sampling_tensors
        assert sampling_tensors is not None
        use_penalties = self._do_penalties
        use_top_p_top_k = self._do_top_p_top_k
        use_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if use_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Temperature scaling happens in float32, dividing in place so that
        # no additional tensor is created.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if use_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if use_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # Probabilities and log probabilities are both kept in float32.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Keep the GPU-side tensors so that logprobs can be pythonized
            # later on.
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Pythonization already happened; nothing to preserve.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        skip_cpu_output = sampling_metadata.skip_sampler_cpu_output
        if not skip_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=skip_cpu_output,
            logits=logits)
def
_greedy_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
)
->
SampleResultType
:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
num_parent_seqs
=
len
(
seq_ids
)
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
parent_ids
=
list
(
range
(
num_parent_seqs
))
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
#place holder token id
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
def
_random_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
random_samples
:
torch
.
Tensor
,
)
->
SampleResultType
:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
*
sampling_params
.
n
#place holder token id
else
:
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
*
num_parent_seqs
#place holder token id
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample a batch by forwarding everything to ``_sample_with_torch``.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Passed through unchanged.
        modify_greedy_probs: Passed through unchanged.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    torch_result = _sample_with_torch(
        probs=probs,
        logprobs=logprobs,
        sampling_metadata=sampling_metadata,
        sampling_tensors=sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )
    return torch_result
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Besides returning its result, this function stores the sampled token
    ids on the module-level ``last_sampler`` record, which therefore must
    have been set (``ZeroOverheadSampler.forward`` does so) before the call.

    Returns:
        A 2-tuple. Its first element is the Pythonized sample results when
        ``sampling_metadata.skip_sampler_cpu_output`` is false, otherwise
        the ``SampleResultArgsType`` bundle for deferred Pythonization. Its
        second element is the (num_tokens, 1) tensor of sampled token ids,
        or ``None`` when ``include_gpu_probs_tensor`` is false.
    '''
    # Bucket the indices of the sequence groups by their sampling type.
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            # Publish the greedy tokens on the shared record.
            # NOTE(review): the random branch below assigns the same
            # attribute, so if one batch has tokens for more than one
            # sampling type only the last assignment survives — confirm
            # that this is intended.
            last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(
                -1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Largest `n` among the prompt-phase groups of this type.
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            # Only seeded sampling passes the sequence groups along.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)

            # Publish the random samples on the shared record.
            last_sampler.sampled_token_ids_tensor = \
                multinomial_samples[sampling_type].to(torch.long)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # GPU<->CPU sync happens here.
        # This also converts the sampler output to a Python object.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    """Turn the sampler's bookkeeping into per-group Python results.

    Single-step scheduling invokes this at sampling time; multi-step
    scheduling defers the call until several GPU-side steps are done.

    Args:
        sample_result_args: Inputs to the Pythonization process, as bundled
            by ``_sample_with_torch``.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per sequence group of
        the batch; groups without a recorded result get ``([], [])``.
    """
    sample_metadata = sample_result_args.sample_metadata
    sampling_metadata = sample_result_args.sampling_metadata
    greedy_samples = sample_result_args.greedy_samples
    multinomial_samples = sample_result_args.multinomial_samples
    sample_results_dict = sample_result_args.sample_results_dict

    random_types = (SamplingType.RANDOM, SamplingType.RANDOM_SEED)
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        seq_group_id, seq_groups = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in random_types:
            sample_results = _random_sample(
                seq_groups, multinomial_samples[sampling_type])
        # Map each group index of this type to its result.
        sample_results_dict.update(zip(seq_group_id, sample_results))

    ordered_results = []
    for group_index in range(len(sampling_metadata.seq_groups)):
        ordered_results.append(sample_results_dict.get(group_index, ([], [])))
    return ordered_results
def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
    logits: Optional[torch.Tensor] = None,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    As a side effect the module-level ``last_sampler`` record gets a new
    ``seq_ids`` list holding, for every group that produced at least one
    output, the parent sequence id of its first output.

    Args:
        maybe_deferred_sample_results: Pythonized per-group results, or the
            deferred-args bundle when ``skip_sampler_cpu_output`` is true.
        sampling_metadata: Metadata of the sampled batch.
        prompt_logprobs: Per-group prompt logprobs (required unless the CPU
            output is skipped).
        sample_logprobs: Per-group sample logprobs (required unless the CPU
            output is skipped).
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g.
            in speculative decoding rejection sampling.
        skip_sampler_cpu_output: When true no per-group outputs are built
            and the deferred args are forwarded instead.
        logits: Forwarded to the ``SamplerOutput`` unchanged.
    """
    group_outputs: List[CompletionSequenceGroupOutput] = []
    last_sampler.seq_ids = []

    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        assert len(sampling_metadata.seq_groups) \
            == len(maybe_deferred_sample_results) \
            == len(prompt_logprobs) \
            == len(sample_logprobs)
        deferred_sample_results_args = None

        per_group = zip(sampling_metadata.seq_groups,
                        maybe_deferred_sample_results, prompt_logprobs,
                        sample_logprobs)
        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in per_group:
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = [
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
                for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs)
            ]
            group_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))
            if seq_outputs:
                last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is None:
        sampled_token_probs = logprobs_tensor = sampled_token_ids = None
    else:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors

    return SamplerOutput(
        outputs=group_outputs,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args,
        logits=logits)
\ No newline at end of file
vllm/zero_overhead/sequence.py
deleted
100644 → 0
View file @
9bf1b213
from
typing
import
Union
from
vllm.sequence
import
Sequence
from
typing
import
Sequence
as
GenericSequence
class ZeroOverheadSequence(Sequence):
    """Sequence that separates confirmed output tokens from placeholders.

    Tokens appended to ``self.data`` may be placeholders whose real value
    is only known later. ``effective_output_len`` counts the output tokens
    whose value has been confirmed through ``fix_last_token_id``; the
    ``zero_overhead_*`` accessors expose only that confirmed prefix.
    """

    def __init__(self,
                 seq_id,
                 inputs,
                 block_size,
                 eos_token_id=None,
                 lora_request=None,
                 prompt_adapter_request=None):
        super().__init__(seq_id, inputs, block_size, eos_token_id,
                         lora_request, prompt_adapter_request)
        # Number of leading output tokens that hold their final value.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Write ``token_id`` into the oldest unconfirmed output slot.

        Does nothing when every output token is already confirmed.
        """
        # Negative distance from the end of the output to the first
        # placeholder; zero means there is no placeholder.
        effect_offset = self.effective_output_len - len(
            self.data.output_token_ids)
        if effect_offset >= 0:
            return
        self.data._output_token_ids[effect_offset] = token_id
        # The list of newly appended tokens is patched only when it is long
        # enough to contain that slot.
        if len(self.data._new_appended_tokens) >= -effect_offset:
            self.data._new_appended_tokens[effect_offset] = token_id
        self.data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1

    def remove_last_place_holder(self, count):
        """Drop the last ``count`` tokens from all token buffers.

        ``_num_computed_tokens`` is reduced by ``count`` as well. A
        non-positive ``count`` is a no-op: slicing with ``[:-0]`` would
        otherwise empty the buffers.
        """
        if count <= 0:
            return
        data = self.data
        data._output_token_ids = data._output_token_ids[:-count]
        data._new_appended_tokens = data._new_appended_tokens[:-count]
        data._cached_all_token_ids = data._cached_all_token_ids[:-count]
        data._num_computed_tokens -= count

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return the confirmed prefix of the output token ids."""
        return self.data.output_token_ids[:self.effective_output_len]

    def zero_overhead_get_output_len(self) -> int:
        """Return the number of confirmed output tokens."""
        return self.effective_output_len

    def zero_overhead_get_last_token_id(self) -> int:
        """Return the last confirmed token id.

        Falls back to the last prompt token while no output token has been
        confirmed.
        """
        if self.effective_output_len == 0:
            return self.data._prompt_token_ids[-1]
        return self.data._output_token_ids[self.effective_output_len - 1]

    def zero_overhead_get_len(self) -> int:
        """Return the prompt length plus the confirmed output length."""
        return self.effective_output_len + len(self.data._prompt_token_ids)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return the confirmed output tokens, or only the new ones.

        If ``delta`` is True, only tokens confirmed since the last call to
        this method are returned: a bare ``int`` for exactly one new token,
        a list otherwise. If ``delta`` is False the whole confirmed output
        is returned.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()

        output_len = self.zero_overhead_get_output_len()

        # Get the number of new tokens.
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        if num_new_tokens == 0:
            return []

        # One past the last confirmed token inside the combined buffer.
        # Assumes `_cached_all_token_ids` is the prompt followed by the
        # output tokens — TODO confirm against SequenceData.
        confirmed_end = self.zero_overhead_get_len()
        all_token_ids = self.data._cached_all_token_ids

        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time).
            return all_token_ids[confirmed_end - 1]

        return all_token_ids[confirmed_end - num_new_tokens:confirmed_end]
vllm/zero_overhead/spec_decode/batch_expansion.py
deleted
100644 → 0
View file @
9bf1b213
from
array
import
array
import
numpy
as
np
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
get_proposal_lens_list
,
record_proposal_token_ids
# Readability aliases: all three are plain ints.
SeqId = int
TargetSeqId = int
TokenId = int

# A SamplingParams instance built with all default arguments.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class
ZeroOverheadBatchExpansionTop1Scorer
(
BatchExpansionTop1Scorer
):
# Variant of BatchExpansionTop1Scorer. score_proposals expands the batch
# using placeholder proposal tokens, and _contract_non_speculative uploads
# its index collections with async_tensor_h2d before indexing with them.
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
# Proposal lengths are read from the vllm.zero_overhead.utils helper rather
# than from `proposals`.
proposal_lens_list
=
get_proposal_lens_list
()
# Hand the proposal token tensor to the zero_overhead bookkeeping helper.
# NOTE(review): what the helper does with it is not visible here.
record_proposal_token_ids
(
proposals
.
proposal_token_ids
)
# All-zero nested list with the same shape as the proposal token tensor;
# the tensor's contents are not read at this point.
proposal_token_ids_list
=
np
.
zeros
(
proposals
.
proposal_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# place holder tokens
# Filter the list to ignore invalid proposals.
# NOTE(review): the rows above are all zeros, so this filter can only drop
# rows if VLLM_INVALID_TOKEN_ID is 0 - confirm the constant's value.
# The comprehension variable reuses the name `proposals`; it is scoped to
# the comprehension and does not rebind the parameter.
proposal_token_ids_list_without_skips
=
[
proposals
for
proposals
in
proposal_token_ids_list
if
VLLM_INVALID_TOKEN_ID
not
in
proposals
]
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
)
# Run the scorer worker on a clone of the request that carries the
# expanded sequence-group metadata.
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
if
not
non_spec_indices
:
# All sequence groups in batch have spec decoding enabled
return
self
.
_contract_batch_all_spec
(
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
)
else
:
# Batch has a mix of spec decode enabled and disabled seq groups
return
self
.
_contract_batch
(
execute_model_req
.
seq_group_metadata_list
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
def
_contract_non_speculative
(
self
,
scores
:
SpeculativeScores
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
non_spec_indices
:
List
[
int
],
non_spec_outputs
:
SpeculativeScores
,
has_prompt_log
:
bool
)
->
SpeculativeScores
:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
# Nothing to merge when every request was speculative.
if
not
non_spec_indices
:
return
scores
if
has_prompt_log
:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta
=
seq_group_metadata_list
# One size per non-spec request: token_chunk_size for prompts, else 1.
nospec_sizes
=
torch
.
tensor
([
seq_meta
[
i
].
token_chunk_size
if
seq_meta
[
i
].
is_prompt
else
1
for
i
in
non_spec_indices
])
# Running total minus one: the index of each request's last entry.
nospec_sampled_token_idxs
=
torch
.
cumsum
(
nospec_sizes
,
0
).
add_
(
-
1
)
else
:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs
=
list
(
range
(
len
(
non_spec_outputs
.
token_ids
)))
# Convert both index collections to int32 tensors on self._device via
# async_tensor_h2d (last positional argument True - presumably pin_memory;
# confirm against vllm.utils).
nospec_sampled_token_idxs
=
async_tensor_h2d
(
nospec_sampled_token_idxs
,
torch
.
int32
,
self
.
_device
,
True
)
non_spec_indices
=
async_tensor_h2d
(
non_spec_indices
,
torch
.
int32
,
self
.
_device
,
True
)
# Write the selected non-speculative results into the first position
# (slice :1) of the matching rows of `scores`, in place.
scores
.
token_ids
[
non_spec_indices
,
:
1
]
=
\
non_spec_outputs
.
token_ids
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
scores
.
probs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
probs
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
scores
.
logprobs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
logprobs
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
# Hidden states are optional; when `scores` carries them the
# non-speculative outputs must carry them too.
if
scores
.
hidden_states
is
not
None
:
assert
non_spec_outputs
.
hidden_states
is
not
None
scores
.
hidden_states
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
hidden_states
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
return
scores
\ No newline at end of file
vllm/zero_overhead/spec_decode/muti_step_worker.py
deleted
100644 → 0
View file @
9bf1b213
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.spec_decode.top1_proproser
import
ZeroOverheadTop1Proposer
from
vllm.zero_overhead.utils
import
SpecStepKind
,
set_spec_step
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.worker.worker_base
import
DelegateWorkerBase
class
ZeroOverheadMultiStepWorker
(
MultiStepWorker
):
# MultiStepWorker variant that installs a ZeroOverheadTop1Proposer, tags
# each draft forward pass with set_spec_step, and filters sampler outputs
# with a device-side index tensor.
def
init_device
(
self
)
->
None
:
# Initialize the wrapped worker's device, then build the proposer. The
# proposer receives a weak proxy to this worker, and max_proposal_len is
# set to self.max_model_len.
self
.
worker
.
init_device
()
self
.
_proposer
=
ZeroOverheadTop1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
self
.
model_runner
.
set_indices_of_seq_with_bonus_tokens
(
indices_of_seq_with_bonus_tokens
)
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
else
:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
# Tag the upcoming forward pass as the first proposal step.
# NOTE(review): set_spec_step comes from vllm.zero_overhead.utils; what
# the tag controls is not visible here.
set_spec_step
(
SpecStepKind
.
FIRST_PROPOSAL
)
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
self
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
# Once a pass has run, the step kind becomes OTHER_PROPOSAL.
set_spec_step
(
SpecStepKind
.
OTHER_PROPOSAL
)
self
.
_append_new_tokens
(
model_output
,
expanded_request
.
seq_group_metadata_list
,
indices_of_seq_with_bonus_tokens
)
model_outputs
.
append
(
model_output
)
# Switch the step kind to SCORE_DECODE after drafting.
set_spec_step
(
SpecStepKind
.
SCORE_DECODE
)
# Keep only the rows that belong to the original (non-expanded) batch.
filtered_model_outputs
=
self
.
_filter_model_output_zero_overhead
(
model_outputs
,
indices_of_seq_with_bonus_tokens
)
# The transpose indicator is always True for this worker.
return
filtered_model_outputs
,
True
def
_filter_model_output_zero_overhead
(
self
,
expanded_batch_outputs
:
List
[
SamplerOutput
],
output_indices_to_retain
:
List
[
int
])
->
List
[
SamplerOutput
]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_outputs (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (List[int]): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
# int32 tensor copy of the indices on self.device, used to index the
# tensor fields below; the Python list still indexes `outputs`.
indices_of_seq_with_bonus_tokens
=
async_tensor_h2d
(
output_indices_to_retain
,
torch
.
int32
,
self
.
device
,
True
)
# Build one filtered SamplerOutput per step; tensor fields that are None
# stay None, and an empty `outputs` list stays empty.
return
[
SamplerOutput
(
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[],
sampled_token_probs
=
(
expanded_batch_output
.
sampled_token_probs
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
sampled_token_probs
is
not
None
else
None
),
logprobs
=
(
expanded_batch_output
.
logprobs
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
logprobs
is
not
None
else
None
),
sampled_token_ids
=
(
expanded_batch_output
.
sampled_token_ids
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
sampled_token_ids
is
not
None
else
None
))
for
expanded_batch_output
in
expanded_batch_outputs
]
\ No newline at end of file
vllm/zero_overhead/spec_decode/spec_decode_worker.py
deleted
100644 → 0
View file @
9bf1b213
import
os
import
copy
from
collections
import
defaultdict
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.distributed.communication_op
import
(
broadcast_tensor_dict
,
get_tp_group
,
tensor_model_parallel_gather
)
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
,
Logits
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTreeStyleScorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
,
prepare_prefill_hidden_states
from
vllm.zero_overhead.spec_decode.batch_expansion
import
ZeroOverheadBatchExpansionTop1Scorer
from
vllm.zero_overhead.utils
import
SpecStepKind
,
record_accepted_token_ids
,
set_spec_step
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
Timer
,
create_logprobs_output
,
create_sequence_group_output
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.utils
import
async_tensor_h2d
,
resolve_obj_by_qualname
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
# Module-level logger created with vLLM's init_logger, keyed by module name.
logger
=
init_logger
(
__name__
)
class
ZeroOverheadSpecDecodeWorker
(
SpecDecodeWorker
):
def init_device(self) -> None:
    """Set up the scorer and proposer workers and pick the scorer.

    Initializes devices and loads models for both workers, optionally
    shares the gathered target ``lm_head`` weight with the proposer,
    initializes the metrics and sampler tensors, builds ``self.scorer``
    and finally configures the model sampler for speculative decoding.
    """
    workers = (self.scorer_worker, self.proposer_worker)

    # The scorer worker comes first in case the proposer model has a
    # smaller TP degree than the target worker.
    for worker in workers:
        worker.init_device()

    # NOTE(cade): load_model is not part of the WorkerBase interface.
    for worker in workers:
        worker.load_model()

    if self._enable_lm_head_weight_load:
        # NOTE(Shangming): gather lm_head weight when tp enabled
        lm_head = self.scorer_worker.model_runner.model_runner.model.lm_head
        gathered_weight: torch.Tensor = tensor_model_parallel_gather(
            lm_head.weight.data,
            dim=0,
        )
        self.proposer_worker.maybe_load_lm_head_weight(gathered_weight)

    self._metrics.init_tensors(self.rank, device_type=self.device)

    # The sampler uses the TP group's local rank when model parallelism is
    # initialized, otherwise this worker's own rank.
    if model_parallel_is_initialized():
        sampler_rank = get_tp_group().local_rank
    else:
        sampler_rank = self.rank
    self.spec_decode_sampler.init_tensors(sampler_rank,
                                          device_type=self.device)

    scorer_cls: Type[SpeculativeScorer]
    if not self.disable_mqa_scorer:
        scorer_cls = MQAScorer
        logger.info(
            "[Speculative Decoding] Use MQA scorer for scoring proposals.")
    else:
        scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
        logger.info("[Speculative Decoding] Use batch "
                    "expansion for scoring proposals.")

    # Tree decoding always uses the tree-style scorer, whatever class was
    # selected (and logged) above.
    chosen_cls = (BatchExpansionTreeStyleScorer
                  if self.tree_decoding else scorer_cls)
    self.scorer = chosen_cls(scorer_worker=self.scorer_worker,
                             device=self.device,
                             vocab_size=self._vocab_size)

    self._configure_model_sampler_for_spec_decode()
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
# Tree decoding: pass any pending kvcache_slot_to_be_moved value on to the
# request and clear it on the worker.
if
self
.
tree_decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
self
.
kvcache_slot_to_be_moved
=
None
# Tag this step as PREFILL before the scorer runs.
# NOTE(review): set_spec_step is defined in vllm.zero_overhead.utils; its
# effect is not visible here.
set_spec_step
(
SpecStepKind
.
PREFILL
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Store hidden states from target model execution, BxD.
hidden_states
=
sampler_output
.
hidden_states
if
hidden_states
is
not
None
:
# Only decodes and prefill terminal chunks need a hidden state.
seq_group_meta_with_hidden
=
[
sg
for
sg
in
execute_model_req
.
seq_group_metadata_list
if
sg
.
do_sample
]
if
any
(
seq
.
is_prompt
for
seq
in
seq_group_meta_with_hidden
):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states
=
hidden_states
[
torch
.
where
(
sampler_output
.
sampled_token_ids
-
VLLM_INVALID_TOKEN_ID
)[
0
]]
# The `if not skip_proposer:` guard below is commented out, so the update
# of previous_hidden_states that follows does not test skip_proposer.
# if not skip_proposer:
# if self.previous_hidden_states is None and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states = HiddenStates(
# hidden_states, seq_group_meta_with_hidden)
# elif self.previous_hidden_states and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states.update(hidden_states,
# seq_group_meta_with_hidden)
# Create the HiddenStates holder on first use, otherwise update it; both
# paths require at least one sampled sequence group.
if
self
.
previous_hidden_states
is
None
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_meta_with_hidden
)
elif
self
.
previous_hidden_states
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
.
update
(
hidden_states
,
seq_group_meta_with_hidden
)
# Store logits from target model execution.
if
self
.
tree_decoding
:
logits
=
sampler_output
.
logits
if
logits
is
not
None
:
if
self
.
previous_logits
is
None
:
self
.
previous_logits
=
Logits
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
else
:
self
.
previous_logits
.
update
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
if
not
skip_proposer
:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
sampler_output
.
prefill_hidden_states
)
# Run the proposer once per configured prefill step, recording the step
# index on the request each time.
for
i
in
range
(
self
.
_num_spec_prefill_steps
):
execute_model_req
.
spec_step_idx
=
i
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
# With logprobs disabled the output goes through
# _serialize_sampler_output_no_logprobs; otherwise it is returned wrapped in
# a single-element list.
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
[
sampler_output
])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
logprobs
=
None
return
sampler_output_to_return
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
List
[
int
]],
List
[
int
]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a 4-tuple: the accepted token ids, the logprobs according to
the scoring model, the list of selected retrieve indices (left empty
unless tree candidates are present) and the accept lengths (None unless
tree candidates are present).
"""
proposal_lens_list
=
proposals
.
proposal_lens
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(
_
,
spec_indices
),
(
_
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
# Speculative rows first, then non-speculative rows; used at the end to
# restore the original batch order.
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, including bonus tokens.
if
non_spec_indices
:
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
# Tree decoding: reorder the second dimension by the proposals'
# retrieve_indices.
if
self
.
tree_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[:,
-
1
:]
if
non_spec_indices
:
bonus_token_ids
=
bonus_token_ids
[
spec_indices
,
:]
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
if
proposals
.
proposal_probs
is
not
None
else
None
if
proposal_probs
is
not
None
and
non_spec_indices
:
proposal_probs
=
proposal_probs
[
spec_indices
]
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
if
non_spec_indices
:
proposal_token_ids
=
proposal_token_ids
[
spec_indices
]
# Get tree buffers.
cart_candidates
=
proposals
.
cart_candidates
if
proposals
.
cart_candidates
is
not
None
else
None
if
cart_candidates
is
not
None
and
non_spec_indices
:
cart_candidates
=
cart_candidates
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
# Seeded requests: map batch position to the request's generator, only for
# stochastic samplers.
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
sampler_extra_kwargs
[
"seeded_seqs"
]
=
{
idx
:
self
.
generators
[
sgm
.
request_id
]
for
idx
,
sgm
in
enumerate
(
seq_group_metadata_list
)
if
sgm
.
sampling_params
.
seed
is
not
None
}
# The typical-acceptance sampler gets the tree candidates plus two empty
# lists that are read back after the sampler call below.
if
isinstance
(
self
.
spec_decode_sampler
,
TypicalAcceptanceSampler
):
sampler_extra_kwargs
[
"cart_candidates"
]
=
cart_candidates
sampler_extra_kwargs
[
"best_candidates"
]
=
[]
sampler_extra_kwargs
[
"accept_lengths"
]
=
[]
# One boolean per sequence group, taken from the first sequence's
# get_first_step_flag().
first_step_flags
=
[]
for
i
,
sgm
in
enumerate
(
seq_group_metadata_list
):
seq
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
first_step_flags
.
append
(
True
if
seq
.
get_first_step_flag
()
else
False
)
sampler_extra_kwargs
[
"first_step_flags"
]
=
first_step_flags
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_with_bonus_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
draft_token_ids
=
proposal_token_ids
,
**
sampler_extra_kwargs
,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
# The non-spec rows are widened to max_proposal_len + 1 columns (or
# max_proposal_len with tree decoding) and every column after the first is
# set to -1.
if
not
self
.
tree_decoding
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
else
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
).
clone
()
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
# The index list is uploaded as an int32 tensor on self.device first.
original_indices
=
async_tensor_h2d
(
original_indices
,
torch
.
int32
,
self
.
device
,
True
)
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
# B x K+1 x D
hidden_states
=
proposal_scores
.
hidden_states
# Defaults returned when no tree candidates are present.
select_indices
=
None
accept_lengths
=
None
select_indices_list
=
[]
if
cart_candidates
is
None
:
if
hidden_states
is
not
None
:
# Only get terminal hidden states for next step
terminal_metadata
=
[
sg
for
sg
in
seq_group_metadata_list
if
sg
.
do_sample
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
# Count of non -1 entries per row, minus one: the position of the last
# accepted token in each row.
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
# b
# Drop non-terminal prefill chunks hidden states.
hidden_states
=
hidden_states
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
accepted_index
=
accepted_index
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
assert
len
(
accepted_index
)
==
hidden_states
.
shape
[
0
]
==
len
(
terminal_metadata
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
# b x 1 x d
second_last_token_hidden_states
=
hidden_states
[:,
-
2
]
# b x d
hidden_states
=
hidden_states
.
gather
(
1
,
index
).
squeeze
(
1
)
# b x d
# Store hidden states from target model for subsequent decode step
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
terminal_metadata
,
second_last_token_hidden_states
)
else
:
# Tree candidates present: read back what the sampler stored in
# sampler_extra_kwargs.
# NOTE(review): those keys are only added above for
# TypicalAcceptanceSampler - confirm no other sampler reaches this branch.
retrieve_indices
=
proposals
.
retrieve_indices
batch_size
=
len
(
seq_group_metadata_list
)
best_candidates
=
sampler_extra_kwargs
[
"best_candidates"
]
accept_lengths
=
sampler_extra_kwargs
[
"accept_lengths"
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
batch_size
,
-
1
,
hs_size
)
# Store logits from target model for subsequent proposal
logits
=
proposal_scores
.
logits
logits
=
logits
.
view
(
batch_size
,
-
1
,
logits
.
shape
[
-
1
])
logits
=
logits
[:,
retrieve_indices
]
# [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list
=
[]
previous_hidden_state_list
=
[]
retrieve_indices
=
retrieve_indices
.
cpu
()
# Per sequence: keep the logits at the accepted position of the best
# candidate and the hidden state at the last selected index.
for
i
in
range
(
batch_size
):
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
previous_logits_list
.
append
(
logit
)
select_indices
=
retrieve_indices
[
best_candidates
[
i
],
:
accept_lengths
[
i
]
+
1
]
hidden_state
=
hidden_states
[
i
,
select_indices
[
-
1
]].
unsqueeze
(
0
)
select_indices_list
.
append
(
select_indices
)
previous_hidden_state_list
.
append
(
hidden_state
)
logits
=
torch
.
cat
(
previous_logits_list
,
dim
=
0
)
self
.
previous_logits
=
Logits
(
logits
,
seq_group_metadata_list
)
hidden_states
=
torch
.
cat
(
previous_hidden_state_list
,
dim
=
0
)
# [batch_size, 1, vocab_size]
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,)
return
accepted_token_ids
,
logprobs
,
select_indices_list
,
accept_lengths
def _create_output_sampler_list(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
    prompt_logprobs: Optional[
        torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
    k: int,
    stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
    """Turn the accepted token ids of one step into a SamplerOutput list.

    The returned list starts with one single-entry ``SamplerOutput`` per
    prefill sequence group, followed by one ``SamplerOutput`` per step that
    holds an entry for every decode sequence.

    Decode entries are created with ``token_id=0``; the accepted ids are
    not converted to a Python list here but passed as a tensor, together
    with the sequence ids, to ``record_accepted_token_ids``.
    """
    batch_size, num_steps = accepted_token_ids.shape
    ids_by_step = accepted_token_ids.transpose(0, 1)

    if self._disable_logprobs:
        # Logprobs are skipped, so nothing logprob-related is serialized
        # from the GPU; empty/dummy lists are produced instead.
        build_lists = self._create_dummy_logprob_lists
        leading_args = (batch_size, num_steps)
    else:
        # Tensors are organised by step (not by sequence) and then
        # serialized into Python lists.
        build_lists = self._create_logprob_lists_from_tensors
        leading_args = (target_logprobs.transpose(0, 1), ids_by_step)
    (ranks_by_step, logprobs_by_step, topk_logprobs_by_step,
     topk_indices_by_step) = build_lists(
         *leading_args, self.scorer_worker.model_config.max_logprobs)

    # Sequence ids and the per-sequence num_logprobs sampling parameter.
    seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
        seq_group_metadata_list)
    num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

    # The accepted ids stay a tensor; they are only recorded here.
    record_accepted_token_ids(accepted_token_ids, seq_ids)

    # Non-terminal prefill chunks show up as rows of -1 only, e.g. for a
    # mixed batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]; terminal
    # chunks carry exactly one generated token at step 0.
    sampler_output_list: List[SamplerOutput] = []

    # Prefills return at most one token, so they are handled separately
    # from the multi-step decodes to avoid padding or repetition.
    for idx, group in enumerate(seq_group_metadata_list):
        if not group.is_prompt:
            # Requests are ordered as prefills|decodes=>no more prefills.
            break

        wanted = num_logprobs_per_seq[idx]
        entry = {
            "token_id": -1,
            "token_id_logprob_rank": 0,
            "token_id_logprob": -float('inf'),
            "topk_token_ids": [-1] * wanted,
            "topk_logprobs": [-float('inf')] * wanted,
            "seq_id": seq_ids[idx],
        }

        if group.do_sample:
            # Terminal chunk: it has a token, always taken from step 0.
            entry["token_id"] = accepted_token_ids[idx][0].item()
            entry["token_id_logprob_rank"] = ranks_by_step[0][idx]
            entry["token_id_logprob"] = logprobs_by_step[0][idx]
            entry["topk_token_ids"] = topk_indices_by_step[0][idx][:wanted]
            entry["topk_logprobs"] = topk_logprobs_by_step[0][idx][:wanted]

        needs_plogs = (group.sampling_params.prompt_logprobs
                       and group.sampling_params.prompt_logprobs > 0)
        chunk_plogs = None
        if prompt_logprobs is not None:
            # Even non-terminal prompt chunks can have logprobs here.
            chunk_plogs = prompt_logprobs[idx]
        elif needs_plogs:
            # Prompt logprobs are requested but `_disable_logprobs` is set.
            seq_data = next(iter(group.seq_data.values()))
            all_prompt_ids = seq_data.get_prompt_token_ids()
            # Keep only the tokens that belong to this chunk.
            chunk_start = seq_data._num_computed_tokens
            chunk_ids = all_prompt_ids[chunk_start:chunk_start +
                                       group.token_chunk_size]
            if chunk_start == 0:
                # No prob is generated for the first token of a sequence.
                chunk_ids = chunk_ids[1:]
            chunk_plogs = [
                create_logprobs_output(
                    token_id=prompt_token,
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    topk_token_ids=[],
                    topk_logprobs=[],
                ) for prompt_token in chunk_ids
            ]
        entry["prompt_logprobs"] = chunk_plogs

        sampler_output_list.append(
            SamplerOutput(outputs=[create_sequence_group_output(
                **entry)]))  # type: ignore

    # Decodes: one SamplerOutput per step (at most K+1). Every step is
    # emitted; there is no early exit for steps whose decode tokens are
    # all -1.
    for step in range(num_steps):
        step_outputs: List[CompletionSequenceGroupOutput] = [
            create_sequence_group_output(
                token_id=0,
                token_id_logprob_rank=ranks_by_step[step][seq_idx],
                token_id_logprob=logprobs_by_step[step][seq_idx],
                seq_id=seq_ids[seq_idx],
                # Each sequence may ask for a different num_logprobs.
                topk_token_ids=topk_indices_by_step[step][seq_idx]
                [:num_logprobs_per_seq[seq_idx]],
                topk_logprobs=topk_logprobs_by_step[step][seq_idx]
                [:num_logprobs_per_seq[seq_idx]],
            ) for seq_idx in range(batch_size)
            # Prompts were already handled above.
            if not seq_group_metadata_list[seq_idx].is_prompt
        ]
        sampler_output_list.append(SamplerOutput(outputs=step_outputs))

    # Update the bookkeeping for sequences with bonus tokens.
    self._track_sequences_with_bonus_tokens(seq_ids,
                                            request_ids_seq_ids_mapping,
                                            ids_by_step)

    rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(k)
    if rejsample_metrics is not None and sampler_output_list:
        sampler_output_list[0].spec_decode_worker_metrics = rejsample_metrics

        # Stage times are logged periodically, at the cadence at which
        # the rejection sampler emits metrics.
        self._maybe_log_stage_times(*stage_times)

    # With chunked prefill enabled the first `n_prefills` entries are the
    # prefill SamplerOutputs; the rest are decodes in multi-step format.
    return sampler_output_list
def
_track_sequences_with_bonus_tokens
(
self
,
seq_ids
:
List
[
int
],
request_ids_seq_ids_mapping
:
Dict
[
str
,
Set
[
int
]],
accepted_token_ids_by_step
:
List
[
List
[
int
]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for
seq_index
,
seq_id
in
enumerate
(
seq_ids
):
# last_token_id = accepted_token_ids_by_step[-1][seq_index]
# if last_token_id == -1:
# self._seq_with_bonus_token_in_last_step.discard(seq_id)
# else:
self
.
_seq_with_bonus_token_in_last_step
.
add
(
seq_id
)
for
request_id
,
sequences
in
request_ids_seq_ids_mapping
.
items
():
self
.
_request_id_seq_id_mapping
[
request_id
].
update
(
sequences
)
\ No newline at end of file
vllm/zero_overhead/spec_decode/top1_proproser.py
deleted
100644 → 0
View file @
9bf1b213
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
record_proposal_lens_list
class ZeroOverheadTop1Proposer(Top1Proposer):
    """Top-1 proposer whose merge step also records the per-sequence
    proposal lengths through ``record_proposal_lens_list``."""

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge the speculation results with the skipped sequences.

        Returns:
            ``(tokens, probs, lens)`` with one row per sequence in the
            batch. Rows that were not speculated hold -1 tokens, zero
            probabilities and a length of 0.
        """
        if maybe_sampler_output is None:
            # No speculative tokens at all: return empty proposals.
            no_tokens = torch.tensor(-1,
                                     dtype=torch.long,
                                     device=self._device).expand(
                                         batch_size, proposal_len)
            no_probs = torch.tensor(0,
                                    dtype=torch.float32,
                                    device=self._device).expand(
                                        batch_size, proposal_len,
                                        self._vocab_size)
            no_lens = torch.tensor(0, dtype=torch.long,
                                   device=self._device).expand(
                                       len(proposal_lens))
            return no_tokens, no_probs, no_lens

        sampled_tokens, sampled_probs, *_ = sampler_output_to_torch(
            maybe_sampler_output, sampler_transposed)

        # Host-side list of proposal lengths: `proposal_len` for every
        # speculated sequence, 0 for the others.
        lens_per_seq = [0] * batch_size
        for seq_idx in nonzero_proposal_len_indices:
            lens_per_seq[seq_idx] = proposal_len
        record_proposal_lens_list(lens_per_seq)

        speculated_rows = async_tensor_h2d(nonzero_proposal_len_indices,
                                           torch.int32, self._device, True)

        # Reformat the GPU tensors so that every sequence has a proposal;
        # a proposal may be empty, e.g. [-1, -1, -1].
        merged_tokens = sampled_tokens.new_full(
            size=(batch_size, *sampled_tokens.shape[1:]),
            fill_value=-1,
        )
        merged_tokens[speculated_rows] = sampled_tokens

        merged_probs = sampled_probs.new_zeros(
            batch_size,
            *sampled_probs.shape[1:],
        )
        merged_probs[speculated_rows] = sampled_probs

        lens_tensor = async_tensor_h2d(lens_per_seq, torch.long,
                                       self._device, True)
        return merged_tokens, merged_probs, lens_tensor
\ No newline at end of file
vllm/zero_overhead/stop_check.py
deleted
100644 → 0
View file @
9bf1b213
from
typing
import
Optional
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
    """Stop checker that queries lengths and the last token through the
    ``zero_overhead_*`` accessors of the sequence."""

    def __init__(self, max_model_len, get_tokenizer_for_seq):
        super().__init__(max_model_len, get_tokenizer_for_seq)

    @staticmethod
    def _zero_overhead_drop_new_text(seq, new_char_count, sampling_params):
        """Remove the newest ``new_char_count`` chars from the output text
        unless the stop text has to stay in the output."""
        if new_char_count and (
                not sampling_params.include_stop_str_in_output):
            seq.output_text = seq.output_text[:-new_char_count]

    def maybe_stop_sequence(
        self,
        seq: ZeroOverheadSequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` as finished when one of the stop criteria holds.

        ``new_char_count`` is the number of chars added to the sequence's
        output text for the newly generated token.
        """
        # Below the minimum number of tokens no stop check applies.
        if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
            return

        # EOS token generated.
        hit_eos = ((not sampling_params.ignore_eos)
                   and seq.zero_overhead_get_last_token_id()
                   == seq.eos_token_id)
        if hit_eos:
            # The EOS text is removed unless explicitly requested, so it
            # is not exposed unintentionally.
            self._zero_overhead_drop_new_text(seq, new_char_count,
                                              sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Stop token encountered (assumes one token produced per step).
        newest_token = seq.zero_overhead_get_last_token_id()
        if newest_token in (sampling_params.stop_token_ids or ()):
            self._zero_overhead_drop_new_text(seq, new_char_count,
                                              sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = newest_token
            return

        # Stop string matched.
        matched = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if matched is not None:
            stop_str, cut_at = matched
            if cut_at != -1:
                seq.output_text = seq.output_text[:cut_at]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Length caps: max_model_len first, then max_tokens.
        if (seq.zero_overhead_get_len() > self._get_max_model_len(lora_req)
                or seq.zero_overhead_get_output_len()
                == sampling_params.max_tokens):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
\ No newline at end of file
vllm/zero_overhead/tokenizer.py
deleted
100644 → 0
View file @
9bf1b213
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
    """Detokenizer that only looks at the first
    ``prompt_len + effective_output_len`` token ids of a sequence."""

    def __init__(self, tokenizer_group):
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Decode the newest effective token of ``seq`` in place.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        visible_len = seq.get_prompt_len() + seq.effective_output_len
        visible_ids = seq.get_token_ids()[:visible_len]
        newest_id = visible_ids[-1]
        context_ids = visible_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Prompt ids are converted to tokens once, up front, so the work
        # is not repeated for every logprob.
        if seq.tokens is None:
            prompt_state = convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=context_ids,
                skip_special_tokens=prms.skip_special_tokens,
            )
            seq.tokens, seq.prefix_offset, seq.read_offset = prompt_state

        # Arguments shared by every incremental detokenization below; the
        # sequence state they refer to is only updated at the very end.
        shared_args = dict(
            tokenizer=tokenizer,
            prev_tokens=seq.tokens,
            prefix_offset=seq.prefix_offset,
            read_offset=seq.read_offset,
            skip_special_tokens=prms.skip_special_tokens,
            spaces_between_special_tokens=prms.
            spaces_between_special_tokens,
        )

        (new_tokens, new_text, next_prefix_offset,
         next_read_offset) = detokenize_incrementally(
             all_input_ids=visible_ids, **shared_args)

        # Decode logprobs.
        newest_logprobs = seq.output_logprobs[-1]
        if newest_logprobs:
            for token_id, sample_logprob in newest_logprobs.items():
                if token_id == newest_id:
                    # The token generated this iteration reuses the text
                    # that was just decoded.
                    sample_logprob.decoded_token = new_text
                elif (sample_logprob.decoded_token is None
                      and token_id != VLLM_INVALID_TOKEN_ID):
                    (_, alternative_text, _,
                     _) = detokenize_incrementally(
                         all_input_ids=context_ids + [token_id],
                         **shared_args)
                    sample_logprob.decoded_token = alternative_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = next_prefix_offset
        seq.read_offset = next_read_offset
        seq.output_text += new_text

        return len(new_text)
\ No newline at end of file
vllm/zero_overhead/utils.py
deleted
100644 → 0
View file @
9bf1b213
from
enum
import
Enum
import
os
import
torch
import
vllm.envs
as
envs
# True only when the VLLM_ZERO_NO_THREAD environment variable is exactly '1'
# at import time.
zero_no_thread = os.getenv('VLLM_ZERO_NO_THREAD') == '1'
def is_zero_no_thread():
    """Return ``zero_no_thread and envs.VLLM_ZERO_OVERHEAD``.

    The result is ``False`` when ``zero_no_thread`` is unset; otherwise it
    is whatever ``envs.VLLM_ZERO_OVERHEAD`` holds.
    """
    if not zero_no_thread:
        return zero_no_thread
    return envs.VLLM_ZERO_OVERHEAD
class SpecStepKind(Enum):
    """Kind of step stored in ``ZeroOverheadSpecContext``.

    NOTE(review): the meaning of each member is only implied by its name
    here (prefill, first/other draft proposal, scoring/decode) - confirm
    against the code that calls ``set_spec_step``.
    """
    # Initial value of both ``step_kind`` and ``last_step``.
    KIND_DEFAULT = 0
    PREFILL = 1
    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3
    SCORE_DECODE = 4
class ZeroOverheadSpecContext:
    """Mutable holder for the values exchanged through the module-level
    ``record_*`` / ``get_*`` / ``set_spec_step`` helpers."""

    def __init__(self):
        # Current and previous step kind both start at the default.
        self.step_kind = self.last_step = SpecStepKind.KIND_DEFAULT
        # Nothing has been recorded yet.
        self.proposal_lens_list = self.proposal_token_ids = None
        self.accepted_token_ids = self.accepted_seq_ids = None


# Single shared context instance, created at import time.
spec_context = ZeroOverheadSpecContext()
def set_spec_step(_step):
    """Make ``_step`` the current step kind of the shared context.

    The step kind held so far is moved to ``last_step`` first.
    """
    ctx = spec_context
    ctx.last_step, ctx.step_kind = ctx.step_kind, _step
def get_spec_step():
    """Return the current step kind of the shared context."""
    current_kind = spec_context.step_kind
    return current_kind
def get_spec_last_step():
    """Return the step kind the shared context held before the latest
    ``set_spec_step`` call."""
    previous_kind = spec_context.last_step
    return previous_kind
def record_proposal_lens_list(list):
    """Store ``list`` as the proposal-lengths list of the shared context.

    The object is stored as-is, without copying.
    """
    # The parameter name shadows the builtin `list`; it is kept because it
    # is part of the function's signature.
    ctx = spec_context
    ctx.proposal_lens_list = list
def get_proposal_lens_list():
    """Return the proposal-lengths list last stored on the shared context
    (``None`` until one is recorded)."""
    recorded_lens = spec_context.proposal_lens_list
    return recorded_lens
def record_proposal_token_ids(tensor):
    """Store ``tensor`` as the proposal token ids of the shared context.

    The object is stored as-is, without copying.
    """
    ctx = spec_context
    ctx.proposal_token_ids = tensor
def get_proposal_token_ids():
    """Return the proposal token ids last stored on the shared context
    (``None`` until some are recorded)."""
    recorded_ids = spec_context.proposal_token_ids
    return recorded_ids
def record_accepted_token_ids(tensor, seq_ids):
    """Store the accepted token ids and their sequence ids on the shared
    context. Both objects are stored as-is, without copying."""
    ctx = spec_context
    ctx.accepted_token_ids, ctx.accepted_seq_ids = tensor, seq_ids
def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as last stored on
    the shared context (both ``None`` until recorded)."""
    ctx = spec_context
    return (ctx.accepted_token_ids, ctx.accepted_seq_ids)
# Zero-overhead scheduling does not run inference on the default stream, in
# order to avoid the stream-synchronization problems that memory allocation
# by the runtime introduces. Streams are cached here, one per target device.
alloc_stream = {}


def zero_overhead_stream(target_device):
    """Return the cached CUDA stream for ``target_device``.

    The stream is created with ``torch.cuda.Stream(device=target_device)``
    on the first request for a device and reused on every later call.

    Args:
        target_device: Device the stream belongs to; it is also the cache
            key, so it must be hashable.

    Returns:
        The ``torch.cuda.Stream`` stored in ``alloc_stream`` for the device.
    """
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]
vllm/zero_overhead/v1/core.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
from
collections
import
defaultdict
from
typing
import
Optional
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
# request_id -> count of output tokens tracked as valid for that request.
# (The name is a misspelling of "requests"; it is kept because other code in
# this module refers to it.)
requsets_valid_token_len = {}


def check_stop(request: Request,
               max_model_len: int,
               pooler_output: Optional[torch.Tensor] = None,
               use_valid_token_len: bool = False) -> bool:
    """Decide whether ``request`` is finished and set its status if so.

    Args:
        request: Request to inspect; ``status`` (and ``stop_reason`` for a
            stop token) are written when the request finishes.
        max_model_len: Cap on prompt tokens plus output tokens.
        pooler_output: For pooling requests, any value other than ``None``
            finishes the request.
        use_valid_token_len: When True, use the count stored in
            ``requsets_valid_token_len`` instead of
            ``request.num_output_tokens`` as the output length. A request
            that is not tracked yet is registered with a count of 0 and
            reported as not stopped.

    Returns:
        True when the request was marked finished, False otherwise.
    """
    if use_valid_token_len:
        if request.request_id not in requsets_valid_token_len:
            requsets_valid_token_len[request.request_id] = 0
            return False
        valid_output_len = requsets_valid_token_len[request.request_id]
    else:
        valid_output_len = request.num_output_tokens

    valid_num_tokens = request.num_prompt_tokens + valid_output_len
    if (valid_num_tokens >= max_model_len
            or valid_output_len >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    if request.pooling_params:
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False

    sampling_params = request.sampling_params
    assert sampling_params is not None

    # Without a single valid output token there is nothing to compare with
    # EOS or the stop tokens. Indexing with `valid_output_len - 1 == -1`
    # would read the newest stored token (which is not valid yet in tracked
    # mode) or fail on an empty list.
    if valid_output_len <= 0:
        return False

    last_token_id = request.output_token_ids[valid_output_len - 1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True
    return False
def zero_overhead_update_from_output(
        scheduler: Scheduler, scheduler_output: SchedulerOutput,
        model_runner_output: ZeroV1ModelRunnerOutput):
    """Apply one zero-overhead model-runner output to the scheduler state.

    The update runs in three passes:

    1. ``fix_req_ids`` / ``fix_sampled_token_ids``: tokens appended to a
       request earlier are overwritten with the ids carried in this
       output, the per-request count in ``requsets_valid_token_len`` is
       advanced, the stop checks run against that count and an
       ``EngineCoreOutput`` is emitted for the corrected tokens.
    2. ``fix_draft_req_ids`` / ``fix_draft_tokens_ids``: the draft token
       ids are installed on each request as ``spec_token_ids``.
    3. ``scheduler.running``: this step's sampled tokens are appended to
       every scheduled request. Stop checks and ``EngineCoreOutput``
       creation only happen when ``model_runner_output.is_output_valid``
       is set.

    Args:
        scheduler: Scheduler whose requests and running list are updated.
        scheduler_output: Output of the ``schedule()`` call for this step.
        model_runner_output: Model output for this step, including the
            ``fix_*`` fields.

    Returns:
        A dict mapping client index to the ``EngineCoreOutputs`` for that
        client.
    """
    global requsets_valid_token_len
    sampled_token_ids = model_runner_output.sampled_token_ids
    spec_token_ids = model_runner_output.spec_token_ids
    logprobs = model_runner_output.logprobs
    prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    pooler_outputs = model_runner_output.pooler_output
    num_nans_in_logits = model_runner_output.num_nans_in_logits

    new_running: list[Request] = []
    outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
    spec_decoding_stats: Optional[SpecDecodingStats] = None

    # Pass 1: fix the tokens of the last model output.
    if model_runner_output.fix_req_ids is not None:
        for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]
            generated_token_ids = model_runner_output.fix_sampled_token_ids[
                req_idx]

            valid_output_len = requsets_valid_token_len.setdefault(req_id, 0)
            # Position of the first token to overwrite, relative to the
            # end of the stored output tokens.
            fix_offset = valid_output_len - request.num_output_tokens
            if isinstance(generated_token_ids, int):
                request._output_token_ids[fix_offset] = generated_token_ids
                request._all_token_ids[fix_offset] = generated_token_ids
                requsets_valid_token_len[req_id] += 1
                generated_token_ids = [generated_token_ids]
            else:
                valid_output_end = (valid_output_len +
                                    len(generated_token_ids) -
                                    request.num_output_tokens)
                if valid_output_end == 0:
                    # An end of 0 would produce an empty slice, so the
                    # slice is left open-ended instead.
                    request._output_token_ids[
                        fix_offset:] = generated_token_ids
                    request._all_token_ids[fix_offset:] = generated_token_ids
                else:
                    request._output_token_ids[
                        fix_offset:valid_output_end] = generated_token_ids
                    request._all_token_ids[
                        fix_offset:valid_output_end] = generated_token_ids
                requsets_valid_token_len[req_id] += len(generated_token_ids)

            stopped = False
            new_logprobs = None
            new_token_ids = generated_token_ids
            kv_transfer_params = None

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            for num_new, _ in enumerate(new_token_ids, 1):
                # The flag is passed by keyword: the third positional
                # parameter of check_stop is `pooler_output`, so a
                # positional bool would be taken as a pooler output and
                # the tracked length would not be used.
                stopped = check_stop(request,
                                     scheduler.max_model_len,
                                     use_valid_token_len=True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

            # NOTE(review): `pooler_outputs` and `logprobs` belong to the
            # current output but are indexed with the position in
            # `fix_req_ids` - confirm that both orderings agree.
            pooler_output = None
            if pooler_outputs:
                pooler_output = pooler_outputs[req_idx]
                stopped = check_stop(request, scheduler.max_model_len,
                                     pooler_output, True)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

            # Extract sample logprobs if needed.
            if request.sampling_params is not None \
                and request.sampling_params.logprobs is not None and logprobs:
                # NOTE: once we support N tokens per step (spec decode),
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_idx, req_idx + 1)

            if new_token_ids and scheduler.structured_output_manager \
                .should_advance(request):
                # NOTE: structured_output_request
                # should not be None if use_structured_output, we have
                # check above, so safe to ignore type warning
                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                    req_id, new_token_ids)

            if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                request.num_nans_in_logits = num_nans_in_logits[req_id]

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids or pooler_output is not None \
                or kv_transfer_params:
                # Add EngineCoreOutput for this Request.
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))
            else:
                assert not prompt_logprobs_tensors

    # Pass 2: fix the draft tokens of the last model output.
    if model_runner_output.fix_draft_req_ids is not None:
        for req_idx, req_id in enumerate(
                model_runner_output.fix_draft_req_ids):
            if req_id not in scheduler.requests:
                continue
            request = scheduler.requests[req_id]
            # Add newly generated spec token ids to the request.
            if model_runner_output.fix_draft_tokens_ids is not None:
                if scheduler.structured_output_manager.should_advance(
                        request):
                    metadata = request.structured_output_request
                    # Needs to happen after new_token_ids are accepted.
                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                        model_runner_output.fix_draft_tokens_ids[req_idx])
                else:
                    request.spec_token_ids = (
                        model_runner_output.fix_draft_tokens_ids[req_idx])

    # Pass 3.
    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in scheduler.running:
        req_id = request.request_id
        if request.is_finished():
            # Finished requests are dropped from the running list and no
            # longer need a tracked length.
            requsets_valid_token_len.pop(req_id, None)
            continue
        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
            new_running.append(request)
            continue

        req_index = model_runner_output.req_id_to_index[req_id]
        generated_token_ids = sampled_token_ids[
            req_index] if sampled_token_ids else []

        scheduled_spec_token_ids = (
            scheduler_output.scheduled_spec_decode_tokens.get(req_id))
        if scheduled_spec_token_ids:
            # num_computed_tokens represents the number of tokens
            # processed in the current step, considering scheduled
            # tokens and rejections. If some tokens are rejected,
            # num_computed_tokens is decreased by the number of rejected
            # tokens, which is given by:
            # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
            num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                   len(generated_token_ids))
            request.num_computed_tokens -= num_tokens_rejected
            spec_decoding_stats = scheduler.make_spec_decoding_stats(
                spec_decoding_stats,
                num_draft_tokens=len(scheduled_spec_token_ids),
                num_accepted_tokens=len(generated_token_ids) - 1)

        # NOTE(woosuk): This has to be executed after updating
        # `request.num_computed_tokens`.
        if request.has_encoder_inputs:
            scheduler._free_encoder_inputs(request)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids
        kv_transfer_params = None

        # Append generated tokens and check for stop. Note that if
        # a request is still being prefilled, we expect the model runner
        # to return empty token ids for the request.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            if model_runner_output.is_output_valid:
                # Keyword argument for the same reason as in pass 1: a
                # positional False would be bound to `pooler_output`.
                stopped = check_stop(request,
                                     scheduler.max_model_len,
                                     use_valid_token_len=False)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)
                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                    break

        pooler_output = None
        if pooler_outputs:
            if model_runner_output.is_output_valid:
                pooler_output = pooler_outputs[req_index]
                stopped = check_stop(request, scheduler.max_model_len,
                                     pooler_output, False)
                if stopped:
                    kv_transfer_params = scheduler._free_request(request)

        # Extract sample logprobs if needed.
        if request.sampling_params is not None \
            and request.sampling_params.logprobs is not None and logprobs:
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and scheduler.structured_output_manager \
            .should_advance(request):
            # NOTE: structured_output_request
            # should not be None if use_structured_output, we have
            # check above, so safe to ignore type warning
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        if num_nans_in_logits is not None and req_id in num_nans_in_logits:
            request.num_nans_in_logits = num_nans_in_logits[req_id]

        # Add newly generated spec token ids to the request.
        if spec_token_ids is not None:
            if scheduler.structured_output_manager.should_advance(request):
                metadata = request.structured_output_request
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        if model_runner_output.is_output_valid:
            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
            if new_token_ids or pooler_output is not None \
                or kv_transfer_params:
                # Add EngineCoreOutput for this Request.
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=req_id,
                        new_token_ids=new_token_ids,
                        finish_reason=request.get_finished_reason(),
                        new_logprobs=new_logprobs,
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        pooling_output=pooler_output,
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        num_cached_tokens=request.num_cached_tokens,
                    ))

        if stopped:
            requsets_valid_token_len.pop(req_id, None)
        else:
            new_running.append(request)

    scheduler.running = new_running

    # KV Connector: update state for finished KV Transfers.
    scheduler._update_from_kv_xfer_finished(model_runner_output)

    # Create EngineCoreOutputs for all clients that have requests with
    # outputs in this step.
    engine_core_outputs = {
        client_index: EngineCoreOutputs(outputs=outs)
        for client_index, outs in outputs.items()
    }

    finished_req_ids = scheduler.finished_req_ids_dict
    if finished_req_ids:
        # Include ids of requests that finished since last outputs
        # were sent.
        for client_index, finished_set in finished_req_ids.items():
            # Set finished request set in EngineCoreOutputs for this client.
            if (eco := engine_core_outputs.get(client_index)) is not None:
                eco.finished_requests = finished_set
            else:
                engine_core_outputs[client_index] = EngineCoreOutputs(
                    finished_requests=finished_set)
        finished_req_ids.clear()

    if engine_core_outputs:
        # Return stats to only one of the front-ends.
        next(iter(engine_core_outputs.values())).scheduler_stats = (
            scheduler.make_stats(spec_decoding_stats))

    return engine_core_outputs
def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Run one schedule / execute / update cycle on ``core``.

    Returns:
        A tuple ``(outputs, executed)``. ``outputs`` is what the update
        step produced and ``executed`` tells whether any token was
        scheduled in this step. When the scheduler holds no requests the
        result is ``({}, False)`` and nothing is scheduled or executed.
    """
    # Requests remaining in the scheduler are either unfinished, or
    # finished and not yet removed from the batch.
    if not core.scheduler.has_requests():
        return {}, False

    sched_out = core.scheduler.schedule()
    runner_out = core.execute_model(sched_out)

    if isinstance(runner_out, ZeroV1ModelRunnerOutput):
        # Zero-overhead outputs go through the module's own update routine.
        step_outputs = zero_overhead_update_from_output(
            core.scheduler, sched_out, runner_out)  # type: ignore
    else:
        step_outputs = core.scheduler.update_from_output(
            sched_out, runner_out)  # type: ignore

    model_executed = sched_out.total_num_scheduled_tokens > 0
    return (step_outputs, model_executed)
\ No newline at end of file
vllm/zero_overhead/v1/eagle.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
class
V1ZeroEagleProposer
(
EagleProposer
):
def __init__(self, vllm_config, device, runner=None):
    """Initialise the base proposer, then add the zero-overhead field."""
    super().__init__(vllm_config, device, runner)
    # Used by `propose` as `max_query_len` when `self.method` is
    # "deepseek_mtp"; starts at 0.
    # NOTE(review): presumably set by the caller before each `propose`
    # call - confirm against the model runner.
    self.spec_scheduler_max_num_tokens = 0
def propose(
    self,
    # [num_tokens]
    target_token_ids: torch.Tensor,
    # [num_tokens]
    target_positions: torch.Tensor,
    # [num_tokens, hidden_size]
    target_hidden_states: torch.Tensor,
    # [num_tokens]
    target_slot_mapping: torch.Tensor,
    # [batch_size]
    next_token_ids: torch.Tensor,
    # [batch_size + 1] starting with 0
    cu_num_tokens: torch.Tensor,
    # [batch_size, max_num_blocks_per_req]
    block_table: torch.Tensor,
    # [batch_size]
    sampling_metadata: SamplingMetadata,
    decoding: bool = False,
) -> torch.Tensor:
    """Generate draft token ids for every request in the batch.

    A first forward pass of the draft model runs over all target tokens and
    picks one draft token per request with ``argmax``. If more than one
    speculative token is requested, the draft model is then run
    ``num_speculative_tokens - 1`` more times, one token per request per
    step, feeding back the previously drafted token.

    Side effects: writes into the persistent ``self.input_ids``,
    ``self.positions`` and ``self.hidden_states`` buffers, mutates the
    attention metadata object in place between steps and, on the full CUDA
    graph paths, copies metadata into ``self.attn_metadata_cudagraph``.

    Args:
        target_token_ids: Token ids of the target model's input.
        target_positions: Position of each target token.
        target_hidden_states: Hidden states from the target model.
        target_slot_mapping: Slot mapping for the target tokens.
        next_token_ids: The next token of each request.
        cu_num_tokens: Cumulative token counts, used as ``query_start_loc``.
        block_table: Per-request block table.
        sampling_metadata: Not read anywhere in this method body.
        decoding: Enables the full-CUDA-graph metadata copy for the first
            pass, is passed on as ``spec_layer_decoding`` on the
            "deepseek_mtp" path, and its negation is passed as
            ``skip_cuda_graphs`` to the first forward context.

    Returns:
        ``draft_token_ids.view(-1, 1)`` when ``num_speculative_tokens == 1``,
        otherwise the per-step draft ids stacked along ``dim=1``.

    Raises:
        ValueError: If ``self.method`` is not "eagle", "eagle3" or
            "deepseek_mtp".

    NOTE(review): the block nesting below was reconstructed from a view in
    which the original indentation was lost; spots where more than one
    nesting is syntactically possible are flagged inline.
    """
    num_tokens = target_token_ids.shape[0]
    batch_size = next_token_ids.shape[0]
    # Index of the last token of each request in the flattened token axis.
    last_token_indices = cu_num_tokens[1:] - 1
    if self.method == "eagle3":
        assert isinstance(self.model, Eagle3LlamaForCausalLM)
        target_hidden_states = self.model.combine_hidden_states(
            target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size
    # Shift the input ids by one token.
    # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
    self.input_ids[:num_tokens - 1] = target_token_ids[1:]
    # Replace the last token with the next token.
    # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
    self.input_ids[last_token_indices] = next_token_ids
    # FA requires seq_len to have dtype int32.
    seq_lens = (target_positions[last_token_indices] + 1).int()
    if self.method in ["eagle", "eagle3"]:
        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
        max_seq_len = seq_lens.max().item()
        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_tokens,
            max_query_len=max_num_tokens,
            query_start_loc=cu_num_tokens,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=target_slot_mapping,
            # TODO(woosuk): Support cascade attention.
            use_cascade=False,
            common_prefix_len=0,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
        )
    elif self.method == "deepseek_mtp":
        # Uses the value stored on the instance instead of computing a max
        # over the batch as the eagle branch does with ``.max().item()``.
        max_query_len = self.spec_scheduler_max_num_tokens
        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=cu_num_tokens,
            seq_lens=seq_lens,
            num_reqs=batch_size,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            slot_mapping=target_slot_mapping,
            spec_layer_decoding=decoding)
        assert self.runner is not None
        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = self.runner.attn_metadata_builders[0].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata)
    else:
        raise ValueError(f"Unsupported method: {self.method}")
    # At this moment, we assume all eagle layers belong to the same KV
    # cache group, thus using the same attention metadata.
    per_layer_attn_metadata = {}
    for layer_name in self.attn_layer_names:
        per_layer_attn_metadata[layer_name] = attn_metadata
    # Pad the token count only when it fits the largest configured
    # cudagraph batch size.
    if self.use_cuda_graph and \
        num_tokens <= self.cudagraph_batch_sizes[-1]:
        num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
    else:
        num_input_tokens = num_tokens
    # copy inputs to buffer for cudagraph
    self.positions[:num_tokens] = target_positions
    self.hidden_states[:num_tokens] = target_hidden_states
    # On the full-CUDA-graph decoding path, mirror the freshly built
    # metadata into the persistent ``attn_metadata_cudagraph`` object.
    if (decoding and self.use_full_cuda_graph
            and num_tokens <= self.cudagraph_batch_sizes[-1]):
        assert self.attn_metadata_cudagraph
        if self.method in ["eagle", "eagle3"]:
            self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                attn_metadata.seq_lens)
            self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                attn_metadata.slot_mapping)
            self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                attn_metadata.query_start_loc)
            self.attn_metadata_cudagraph.block_table[:batch_size] = (
                attn_metadata.block_table)
        elif self.method == "deepseek_mtp":
            self.attn_metadata_cudagraph.num_actual_tokens = (
                attn_metadata.num_actual_tokens)
            self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                attn_metadata.query_start_loc)
            self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                attn_metadata.slot_mapping)
            self.attn_metadata_cudagraph.num_decodes = (
                attn_metadata.num_decodes)
            self.attn_metadata_cudagraph.num_decode_tokens = (
                attn_metadata.num_decode_tokens)
            self.attn_metadata_cudagraph.num_prefills = (
                attn_metadata.num_prefills)
            # NOTE(review): assumed to be nested under the "deepseek_mtp"
            # branch, with both copies guarded by the ``decode`` check --
            # confirm against the original indentation.
            if attn_metadata.decode is not None:
                self.attn_metadata_cudagraph.decode.block_table[
                    :attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.block_table)
                self.attn_metadata_cudagraph.decode.seq_lens[
                    :attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)
    # First draft pass over all (possibly padded) target tokens.
    with set_forward_context(per_layer_attn_metadata,
                             self.vllm_config,
                             num_tokens=num_input_tokens,
                             skip_cuda_graphs=not decoding):
        ret_hidden_states = self.model(
            self.input_ids[:num_input_tokens],
            self.positions[:num_input_tokens],
            self.hidden_states[:num_input_tokens],
        )
        # The "deepseek_mtp" model output is used as a single value; the
        # other methods' output is unpacked into two values.
        if self.method == "deepseek_mtp":
            last_hidden_states = ret_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states, None)
    # Draft tokens are picked by argmax over the logits.
    draft_token_ids = logits.argmax(dim=-1)
    # Early exit if there is only one draft token to be generated.
    if self.num_speculative_tokens == 1:
        # [batch_size, 1]
        return draft_token_ids.view(-1, 1)
    # TODO: Currently, MTP module released by deepseek only has
    # one layer. Adapt this code to support multiple layers once
    # there's a multi-layer MTP module.
    # Generate the remaining draft tokens.
    draft_token_ids_list = [draft_token_ids]
    positions = target_positions[last_token_indices]
    if self.method == "deepseek_mtp":
        hidden_states = last_hidden_states[last_token_indices]
    else:
        hidden_states = hidden_states[last_token_indices]
    if self.use_cuda_graph and \
        batch_size <= self.cudagraph_batch_sizes[-1]:
        input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
    else:
        input_batch_size = batch_size
    # From here on each request contributes exactly one query token.
    attn_metadata.num_actual_tokens = batch_size
    attn_metadata.max_query_len = 1
    # presumably ``self.arange`` holds 0..N so this yields [0, 1, ...,
    # batch_size] -- TODO confirm where ``self.arange`` is created.
    attn_metadata.query_start_loc = self.arange[:batch_size + 1]
    if isinstance(attn_metadata, MLACommonMetadata):
        # Treat the whole batch as decode requests and rebuild the decode
        # metadata from the runner's block table.
        attn_metadata.num_decodes = batch_size
        attn_metadata.num_decode_tokens = batch_size
        attn_metadata.num_prefills = 0
        # This rebinds the ``block_table`` argument; the slot-mapping
        # gather in the loop below then uses this tensor.
        block_table = self.runner.attn_metadata_builders[
            0].block_table.get_device_tensor()[:batch_size, ...]
        attn_metadata.decode = self.runner.attn_metadata_builders[
            0]._build_decode(
                block_table_tensor=block_table,
                seq_lens=seq_lens,
            )
    for i in range(self.num_speculative_tokens - 1):
        # Update the inputs.
        # cast to int32 is crucial when eagle model is compiled.
        # tensor.argmax() returns int64 by default.
        input_ids = draft_token_ids_list[-1].int()
        positions += 1
        # NOTE(woosuk): We should handle the case where the draft model
        # generates tokens beyond the max model length. Since it is complex
        # to remove such requests from the batch, we keep them in the batch
        # but adjust the position ids and slot mappings to avoid the
        # out-of-range access during the model execution. The draft tokens
        # generated with this adjustment should be ignored.
        exceeds_max_model_len = positions >= self.max_model_len
        # Mask out the position ids that exceed the max model length.
        # Otherwise, we may get out-of-range error in RoPE.
        clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
        # NOTE(review): the original indentation is not recoverable here.
        # The ``max_seq_len`` / ``seq_lens`` updates are placed in the
        # non-MLA branch on the assumption that only that metadata type has
        # top-level ``seq_lens`` and ``max_seq_len`` fields (the MLA paths
        # in this method only ever touch ``decode.seq_lens``) -- confirm
        # against the real file.
        if isinstance(attn_metadata, MLACommonMetadata):
            attn_metadata.decode.seq_lens += 1
        else:
            attn_metadata.seq_lens += 1
            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
        # Compute the slot mapping.
        block_numbers = clamped_positions // self.block_size
        block_ids = block_table.gather(dim=1,
                                       index=block_numbers.view(-1, 1))
        block_ids = block_ids.view(-1)
        attn_metadata.slot_mapping = (block_ids * self.block_size +
                                      clamped_positions % self.block_size)
        # Mask out the slot mappings that exceed the max model length.
        # Otherwise, the KV cache will be inadvertently updated with the
        # padding tokens.
        attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                PADDING_SLOT_ID)
        # copy inputs to buffer for cudagraph
        self.input_ids[:batch_size] = input_ids
        self.positions[:batch_size] = clamped_positions
        self.hidden_states[:batch_size] = hidden_states
        # Unlike the first pass, this check does not involve ``decoding``.
        if (self.use_full_cuda_graph
                and batch_size <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph
            if self.method in ["eagle", "eagle3"]:
                self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                    attn_metadata.seq_lens)
                self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
                    attn_metadata.slot_mapping)
                # NOTE(review): both copies are assumed to sit under
                # ``i == 0``; neither source value is reassigned inside
                # this loop -- confirm the original nesting.
                if i == 0:
                    self.attn_metadata_cudagraph.query_start_loc[
                        :batch_size + 1] = (
                            attn_metadata.query_start_loc)
                    self.attn_metadata_cudagraph.block_table[:batch_size] = (
                        attn_metadata.block_table)
            elif self.method == "deepseek_mtp":
                self.attn_metadata_cudagraph.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                self.attn_metadata_cudagraph.slot_mapping[
                    :attn_metadata.num_decode_tokens] = (
                        attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.num_decodes = (
                    attn_metadata.num_decodes)
                self.attn_metadata_cudagraph.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                self.attn_metadata_cudagraph.num_prefills = (
                    attn_metadata.num_prefills)
                self.attn_metadata_cudagraph.decode.seq_lens[
                    :attn_metadata.num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)
                # NOTE(review): same nesting assumption as the eagle branch.
                if i == 0:
                    self.attn_metadata_cudagraph.query_start_loc[
                        :batch_size + 1] = (
                            attn_metadata.query_start_loc)
                    self.attn_metadata_cudagraph.decode.block_table[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.block_table)
        # Run the model.
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=input_batch_size):
            ret_hidden_states = self.model(
                self.input_ids[:input_batch_size],
                self.positions[:input_batch_size],
                self.hidden_states[:input_batch_size],
            )
            # Keep only the first ``batch_size`` rows (dropping any
            # cudagraph padding) as the hidden states for the next step.
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states[:batch_size]
            else:
                last_hidden_states, hidden_states = ret_hidden_states
                hidden_states = hidden_states[:batch_size]
        logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                           None)
        # TODO(wenlong): get more than one token for tree attention
        draft_token_ids = logits.argmax(dim=-1)
        draft_token_ids_list.append(draft_token_ids)
    # [batch_size, num_speculative_tokens]
    draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
    return draft_token_ids
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment