Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1bd3ae33
Commit
1bd3ae33
authored
Oct 11, 2025
by
zhuwenwen
Browse files
skip silu_mul_fp8_quant_deep_gemm_cuda and remove zero_overhead
parent
9bf1b213
Changes
22
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
31 additions
and
3272 deletions
+31
-3272
vllm/config/model.py
vllm/config/model.py
+4
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+0
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+19
-19
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-17
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-2
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+0
-655
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+0
-171
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+0
-500
vllm/zero_overhead/sequence.py
vllm/zero_overhead/sequence.py
+0
-64
vllm/zero_overhead/spec_decode/batch_expansion.py
vllm/zero_overhead/spec_decode/batch_expansion.py
+0
-141
vllm/zero_overhead/spec_decode/muti_step_worker.py
vllm/zero_overhead/spec_decode/muti_step_worker.py
+0
-137
vllm/zero_overhead/spec_decode/spec_decode_worker.py
vllm/zero_overhead/spec_decode/spec_decode_worker.py
+0
-565
vllm/zero_overhead/spec_decode/top1_proproser.py
vllm/zero_overhead/spec_decode/top1_proproser.py
+0
-84
vllm/zero_overhead/stop_check.py
vllm/zero_overhead/stop_check.py
+0
-77
vllm/zero_overhead/tokenizer.py
vllm/zero_overhead/tokenizer.py
+0
-84
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+0
-71
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+0
-357
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+0
-316
No files found.
vllm/config/model.py
View file @
1bd3ae33
...
...
@@ -276,6 +276,9 @@ class ModelConfig:
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
None
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
v0.12.0 or v1.0.0, whichever is sooner."""
enable_chunked_prefill
:
Optional
[
bool
]
=
None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
# Multimodal config and init vars
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
...
...
@@ -320,6 +323,7 @@ class ModelConfig:
factors
.
append
(
self
.
rope_scaling
)
factors
.
append
(
self
.
rope_theta
)
factors
.
append
(
self
.
video_pruning_rate
)
factors
.
append
(
self
.
enable_chunked_prefill
)
# hf_config can control how the model looks!
try
:
...
...
vllm/entrypoints/llm.py
View file @
1bd3ae33
...
...
@@ -56,7 +56,6 @@ from vllm.v1.engine.llm_engine import LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
import
vllm.envs
as
envs
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
if
TYPE_CHECKING
:
...
...
@@ -300,10 +299,6 @@ class LLM:
log_non_default_args
(
engine_args
)
# Create the Engine (autoselects V0 vs V1)
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
llm_engine
=
ZeroOverheadEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
else
:
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
1bd3ae33
...
...
@@ -1840,8 +1840,7 @@ class FusedMoE(CustomOp):
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
)
routed_scaling_factor
=
routed_scaling_factor
)
if
indices_type
is
not
None
:
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
elif
e_score_correction_bias
is
not
None
:
...
...
vllm/v1/engine/core.py
View file @
1bd3ae33
...
...
@@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, TypeVar, Union
import
msgspec
from
vllm
import
envs
from
vllm.zero_overhead.v1.core
import
engine_core_step
import
zmq
from
vllm.config
import
ParallelConfig
,
VllmConfig
...
...
@@ -277,8 +276,6 @@ class EngineCore:
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
if
envs
.
VLLM_ZERO_OVERHEAD
:
return
engine_core_step
(
self
)
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1bd3ae33
...
...
@@ -1458,16 +1458,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
if
envs
.
VLLM_ZERO_OVERHEAD
:
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
else
:
#
if envs.VLLM_ZERO_OVERHEAD:
#
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
#
self.device, non_blocking=True)
#
logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
#
non_blocking=True)
#
target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
#
self.device, non_blocking=True)
#
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
#
self.device, non_blocking=True)
#
else:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
1bd3ae33
...
...
@@ -34,8 +34,6 @@ from vllm.v1.utils import report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
...
...
@@ -200,11 +198,6 @@ class Worker(WorkerBase):
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Construct the model runner
if
envs
.
VLLM_ZERO_OVERHEAD
:
logger
.
info
(
'use zero overhead model_runner'
)
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
...
...
@@ -451,12 +444,6 @@ class Worker(WorkerBase):
all_gather_group
=
get_tp_group
(),
all_gather_tensors
=
all_gather_tensors
))
if
envs
.
VLLM_ZERO_OVERHEAD
:
use_stream
=
zero_overhead_stream
(
self
.
device
)
with
torch
.
cuda
.
stream
(
use_stream
):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
else
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
...
...
vllm/worker/worker_base.py
View file @
1bd3ae33
...
...
@@ -52,8 +52,7 @@ class WorkerBase:
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
model_input
:
Optional
[
ModelRunnerInputBase
]
=
None
# TODO
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
__init__
(
...
...
vllm/zero_overhead/llm_engine.py
deleted
100644 → 0
View file @
9bf1b213
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/model_runner.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
import
itertools
from
typing
import
List
,
Optional
,
Set
from
vllm.lora.layers
import
LoRAMapping
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.sampler
import
get_last_sampler
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_proposal_token_ids
,
get_spec_last_step
,
get_spec_step
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_update_input_tokens
(
accepted_req_ids
,
accepted_req_ids_len
,
accepted_token_ids
,
accepted_token_len
,
chidren_req_ids
,
chidren_req_ids_len
,
input_tokens
,
input_tokens_len
,
input_positions
,
seq_lens
,
seq_lens_meta
,
seq_lens_tensor
,
slot_mapping
,
seq_start_loc
,
context_lens_tensor
,
):
chidren_req_ids_
=
tl
.
load
(
chidren_req_ids
+
tl
.
arange
(
0
,
chidren_req_ids_len
))
accepted_req_ids_
=
tl
.
load
(
accepted_req_ids
+
tl
.
arange
(
0
,
chidren_req_ids_len
))
for
seq_id_idx
in
range
(
chidren_req_ids_len
/
2
):
seq_id
=
chidren_req_ids_
[
2
*
seq_id_idx
]
for
i
in
range
(
accepted_req_ids_len
):
if
seq_id
==
accepted_req_ids_
[
i
]:
accepted_token_ids_
=
tl
.
load
(
accepted_token_ids
+
tl
.
arange
(
i
*
accepted_token_len
,
tl
.
arange
(
0
,
accepted_token_len
)))
accepted_token_counter
=
0
for
j
in
range
(
accepted_token_len
):
if
accepted_token_ids_
[
j
]
==
-
1
:
break
accepted_token_counter
+=
1
if
accepted_token_counter
==
accepted_token_len
:
tl
.
store
(
input_tokens
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
accepted_token_ids_
[
-
2
:])
else
:
tl
.
store
(
input_tokens
+
seq_id_idx
*
2
,
0
)
tl
.
store
(
input_tokens
+
seq_id_idx
*
2
+
1
,
accepted_token_ids_
[
accepted_token_counter
-
1
])
input_pos
=
tl
.
load
(
input_positions
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
))
input_pos
[
0
]
=
0
input_pos
[
1
]
=
input_pos
[
1
]
-
(
accepted_req_ids_len
-
accepted_token_counter
)
tl
.
store
(
input_positions
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
tl
.
store
(
context_lens_tensor
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
input_pos
[
0
]
=
-
1
tl
.
store
(
slot_mapping
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
input_pos
[
0
]
=
1
input_pos
[
1
]
=
input_pos
[
1
]
+
1
tl
.
store
(
seq_lens
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
tl
.
store
(
seq_lens_meta
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
tl
.
store
(
seq_lens_tensor
+
seq_id_idx
*
2
+
tl
.
arange
(
0
,
2
),
input_pos
)
seq_lens_
=
tl
.
load
(
seq_lens
+
tl
.
arange
(
0
,
input_tokens_len
))
seq_start_loc_
=
tl
.
zero_like
(
seq_start_loc
)
for
i
in
range
(
input_tokens_len
):
seq_start_loc_
[
i
+
1
]
=
seq_start_loc_
[
i
]
+
seq_lens_
[
i
]
tl
.
store
(
seq_start_loc
+
tl
.
arange
(
0
,
input_tokens_len
+
1
),
seq_start_loc_
)
class
ZeroOverheadModelInputForGpuBuilder
(
ModelInputForGPUBuilder
):
def
__init__
(
self
,
runner
,
finished_requests_ids
=
None
):
super
().
__init__
(
runner
,
finished_requests_ids
)
self
.
req_ids
=
[]
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
req_ids
.
clear
()
return
super
().
prepare
(
finished_requests_ids
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
seq_ids
=
seq_group_metadata
.
seq_data
.
keys
()
n_seqs
=
len
(
seq_ids
)
seq_ids
=
list
(
seq_ids
)
for
seq_idx
in
range
(
n_seqs
):
self
.
req_ids
.
append
(
seq_ids
[
seq_idx
])
return
super
().
add_seq_group
(
seq_group_metadata
)
def
build
(
self
)
->
ModelInputForGPU
:
model_input
=
super
().
build
()
last_sampler
=
get_last_sampler
()
spec_step
=
get_spec_step
()
last_step
=
get_spec_last_step
()
if
last_sampler
is
not
None
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
update_indices
=
[]
select_indices
=
[]
query_idx
=
0
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
if
seq_id
==
seq_id_
:
select_indices
.
append
(
j
)
update_indices
.
append
(
query_idx
)
break
query_idx
+=
model_input
.
query_lens
[
i
]
if
len
(
select_indices
)
>
0
and
last_sampler
.
sampled_token_ids_tensor
is
not
None
:
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
select_indices
,
0
]
if
spec_step
==
SpecStepKind
.
OTHER_PROPOSAL
:
if
last_step
==
SpecStepKind
.
OTHER_PROPOSAL
:
# copy last sampled token ids to input tokens directly.
update_indices
=
[
i
for
i
in
range
(
len
(
self
.
req_ids
))]
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
update_indices
,
0
]
if
last_step
==
SpecStepKind
.
FIRST_PROPOSAL
:
# TODO: ajust input tokens number to 1 per request.
update_indices
=
[
i
for
i
in
range
(
len
(
self
.
req_ids
))]
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
update_indices
,
0
]
if
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
proposal_token_ids
=
get_proposal_token_ids
()
shape
=
proposal_token_ids
.
shape
batch_size
=
shape
[
0
]
proposal_len
=
shape
[
1
]
update_indices
=
[]
for
i
in
range
(
batch_size
):
for
j
in
range
(
proposal_len
):
update_indices
.
append
(
i
*
(
proposal_len
+
1
)
+
j
+
1
)
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
proposal_token_ids
.
view
(
-
1
)
if
spec_step
==
SpecStepKind
.
FIRST_PROPOSAL
:
if
last_step
==
SpecStepKind
.
PREFILL
:
# TODO: when last step is prefill, just update the input ids for last seqence_id onely.
pass
if
last_step
==
SpecStepKind
.
SCORE_DECODE
:
# TODO: when last step is score decode, fix input ids、seq_lens、input_positions use accepte token ids
accept_token_ids
,
accept_seq_ids
=
get_accepted_token_ids
()
chidren_req_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
grid
=
[
1
,
1
,
1
]
_update_input_tokens
[
grid
](
accept_seq_ids
,
accept_seq_ids
.
shape
[
0
],
accept_token_ids
,
accept_token_ids
.
shape
[
1
],
chidren_req_ids
,
chidren_req_ids
.
shape
[
0
],
model_input
.
input_tokens
,
model_input
.
input_tokens
.
shape
[
0
],
model_input
.
input_positions
,
model_input
.
seq_lens
,
model_input
.
attn_metadata
.
seq_lens_tensor
,
model_input
.
attn_metadata
.
seq_lens
,
model_input
.
attn_metadata
.
slot_mapping
,
model_input
.
attn_metadata
.
seq_start_loc
,
model_input
.
attn_metadata
.
context_lens_tensor
,
)
return
model_input
vllm/zero_overhead/sampler.py
deleted
100644 → 0
View file @
9bf1b213
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/sequence.py
deleted
100644 → 0
View file @
9bf1b213
from
typing
import
Union
from
vllm.sequence
import
Sequence
from
typing
import
Sequence
as
GenericSequence
class
ZeroOverheadSequence
(
Sequence
):
def
__init__
(
self
,
seq_id
,
inputs
,
block_size
,
eos_token_id
=
None
,
lora_request
=
None
,
prompt_adapter_request
=
None
):
super
().
__init__
(
seq_id
,
inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
)
self
.
effective_output_len
:
int
=
0
def
fix_last_token_id
(
self
,
token_id
:
int
)
->
None
:
effect_offset
=
self
.
effective_output_len
-
len
(
self
.
data
.
output_token_ids
)
if
effect_offset
<
0
:
self
.
data
.
_output_token_ids
[
effect_offset
]
=
token_id
if
len
(
self
.
data
.
_new_appended_tokens
)
>=
effect_offset
*
-
1
:
self
.
data
.
_new_appended_tokens
[
effect_offset
]
=
token_id
self
.
data
.
_cached_all_token_ids
[
effect_offset
]
=
token_id
self
.
effective_output_len
+=
1
def
remove_last_place_holder
(
self
,
count
):
self
.
data
.
_output_token_ids
=
self
.
data
.
_output_token_ids
[:
-
1
*
count
]
self
.
data
.
_new_appended_tokens
=
self
.
data
.
_new_appended_tokens
[:
-
1
*
count
]
self
.
data
.
_cached_all_token_ids
=
self
.
data
.
_cached_all_token_ids
[:
-
1
*
count
]
self
.
data
.
_num_computed_tokens
-=
count
def
zero_overhead_get_output_token_ids
(
self
)
->
tuple
[
int
,
...]:
return
self
.
data
.
output_token_ids
[:
self
.
effective_output_len
]
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
effective_output_len
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
if
self
.
effective_output_len
==
0
:
return
self
.
data
.
_prompt_token_ids
[
-
1
]
return
self
.
data
.
_output_token_ids
[
self
.
effective_output_len
-
1
]
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
effective_output_len
+
len
(
self
.
data
.
_prompt_token_ids
)
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
Union
[
GenericSequence
[
int
],
int
]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
zero_overhead_get_output_token_ids
()
output_len
=
self
.
zero_overhead_get_output_len
()
# Get the number of new tokens
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
self
.
_last_output_token_ids_offset
=
output_len
# Return new tokens
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
self
.
effective_output_len
-
1
]
if
num_new_tokens
==
0
:
return
[]
effect_offset
=
self
.
effective_output_len
-
len
(
self
.
data
.
output_token_ids
)
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:
effect_offset
]
vllm/zero_overhead/spec_decode/batch_expansion.py
deleted
100644 → 0
View file @
9bf1b213
from
array
import
array
import
numpy
as
np
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
get_proposal_lens_list
,
record_proposal_token_ids
SeqId
=
int
TargetSeqId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class
ZeroOverheadBatchExpansionTop1Scorer
(
BatchExpansionTop1Scorer
):
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
proposal_lens_list
=
get_proposal_lens_list
()
record_proposal_token_ids
(
proposals
.
proposal_token_ids
)
proposal_token_ids_list
=
np
.
zeros
(
proposals
.
proposal_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# place holder tokens
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips
=
[
proposals
for
proposals
in
proposal_token_ids_list
if
VLLM_INVALID_TOKEN_ID
not
in
proposals
]
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
if
not
non_spec_indices
:
# All sequence groups in batch have spec decoding enabled
return
self
.
_contract_batch_all_spec
(
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
)
else
:
# Batch has a mix of spec decode enabled and disabled seq groups
return
self
.
_contract_batch
(
execute_model_req
.
seq_group_metadata_list
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
def
_contract_non_speculative
(
self
,
scores
:
SpeculativeScores
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
non_spec_indices
:
List
[
int
],
non_spec_outputs
:
SpeculativeScores
,
has_prompt_log
:
bool
)
->
SpeculativeScores
:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if
not
non_spec_indices
:
return
scores
if
has_prompt_log
:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta
=
seq_group_metadata_list
nospec_sizes
=
torch
.
tensor
([
seq_meta
[
i
].
token_chunk_size
if
seq_meta
[
i
].
is_prompt
else
1
for
i
in
non_spec_indices
])
nospec_sampled_token_idxs
=
torch
.
cumsum
(
nospec_sizes
,
0
).
add_
(
-
1
)
else
:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs
=
list
(
range
(
len
(
non_spec_outputs
.
token_ids
)))
nospec_sampled_token_idxs
=
async_tensor_h2d
(
nospec_sampled_token_idxs
,
torch
.
int32
,
self
.
_device
,
True
)
non_spec_indices
=
async_tensor_h2d
(
non_spec_indices
,
torch
.
int32
,
self
.
_device
,
True
)
scores
.
token_ids
[
non_spec_indices
,
:
1
]
=
\
non_spec_outputs
.
token_ids
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
scores
.
probs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
probs
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
scores
.
logprobs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
logprobs
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
if
scores
.
hidden_states
is
not
None
:
assert
non_spec_outputs
.
hidden_states
is
not
None
scores
.
hidden_states
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_outputs
.
hidden_states
[
nospec_sampled_token_idxs
].
unsqueeze
(
1
)
return
scores
\ No newline at end of file
vllm/zero_overhead/spec_decode/muti_step_worker.py
deleted
100644 → 0
View file @
9bf1b213
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.spec_decode.top1_proproser
import
ZeroOverheadTop1Proposer
from
vllm.zero_overhead.utils
import
SpecStepKind
,
set_spec_step
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.worker.worker_base
import
DelegateWorkerBase
class
ZeroOverheadMultiStepWorker
(
MultiStepWorker
):
def
init_device
(
self
)
->
None
:
self
.
worker
.
init_device
()
self
.
_proposer
=
ZeroOverheadTop1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
self
.
model_runner
.
set_indices_of_seq_with_bonus_tokens
(
indices_of_seq_with_bonus_tokens
)
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
else
:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
set_spec_step
(
SpecStepKind
.
FIRST_PROPOSAL
)
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
self
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
set_spec_step
(
SpecStepKind
.
OTHER_PROPOSAL
)
self
.
_append_new_tokens
(
model_output
,
expanded_request
.
seq_group_metadata_list
,
indices_of_seq_with_bonus_tokens
)
model_outputs
.
append
(
model_output
)
set_spec_step
(
SpecStepKind
.
SCORE_DECODE
)
filtered_model_outputs
=
self
.
_filter_model_output_zero_overhead
(
model_outputs
,
indices_of_seq_with_bonus_tokens
)
return
filtered_model_outputs
,
True
def
_filter_model_output_zero_overhead
(
self
,
expanded_batch_outputs
:
List
[
SamplerOutput
],
output_indices_to_retain
:
List
[
int
])
->
List
[
SamplerOutput
]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
indices_of_seq_with_bonus_tokens
=
async_tensor_h2d
(
output_indices_to_retain
,
torch
.
int32
,
self
.
device
,
True
)
return
[
SamplerOutput
(
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[],
sampled_token_probs
=
(
expanded_batch_output
.
sampled_token_probs
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
sampled_token_probs
is
not
None
else
None
),
logprobs
=
(
expanded_batch_output
.
logprobs
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
logprobs
is
not
None
else
None
),
sampled_token_ids
=
(
expanded_batch_output
.
sampled_token_ids
[
indices_of_seq_with_bonus_tokens
]
if
expanded_batch_output
.
sampled_token_ids
is
not
None
else
None
))
for
expanded_batch_output
in
expanded_batch_outputs
]
\ No newline at end of file
vllm/zero_overhead/spec_decode/spec_decode_worker.py
deleted
100644 → 0
View file @
9bf1b213
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/spec_decode/top1_proproser.py
deleted
100644 → 0
View file @
9bf1b213
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
record_proposal_lens_list
class
ZeroOverheadTop1Proposer
(
Top1Proposer
):
def
_merge_outputs
(
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
tensor
(
-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
)
proposal_probs
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
,
self
.
_vocab_size
)
proposal_lens_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
len
(
proposal_lens
))
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
,
*
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
proposal_lens_list
=
[
0
for
i
in
range
(
batch_size
)]
for
indices
in
nonzero_proposal_len_indices
:
proposal_lens_list
[
indices
]
=
proposal_len
record_proposal_lens_list
(
proposal_lens_list
)
nonzero_proposal_len_indices
=
async_tensor_h2d
(
nonzero_proposal_len_indices
,
torch
.
int32
,
self
.
_device
,
True
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
proposal_tokens
.
new_full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
proposal_probs
.
new_zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
,
)
proposal_lens_tensor
=
async_tensor_h2d
(
proposal_lens_list
,
torch
.
long
,
self
.
_device
,
True
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
\ No newline at end of file
vllm/zero_overhead/stop_check.py
deleted
100644 → 0
View file @
9bf1b213
from
typing
import
Optional
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadStopChecker
(
StopChecker
):
def
__init__
(
self
,
max_model_len
,
get_tokenizer_for_seq
):
super
().
__init__
(
max_model_len
,
get_tokenizer_for_seq
)
def
maybe_stop_sequence
(
self
,
seq
:
ZeroOverheadSequence
,
new_char_count
:
int
,
sampling_params
:
SamplingParams
,
lora_req
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if
seq
.
zero_overhead_get_output_len
()
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
zero_overhead_get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
zero_overhead_get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
zero_overhead_get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
zero_overhead_get_output_len
()
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
\ No newline at end of file
vllm/zero_overhead/tokenizer.py
deleted
100644 → 0
View file @
9bf1b213
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadDetokenizer
(
Detokenizer
):
def
__init__
(
self
,
tokenizer_group
):
super
().
__init__
(
tokenizer_group
)
def
decode_sequence_inplace
(
self
,
seq
:
ZeroOverheadSequence
,
prms
:
SamplingParams
)
->
int
:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
effective_output_len
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if
seq
.
tokens
is
None
:
(
seq
.
tokens
,
seq
.
prefix_offset
,
seq
.
read_offset
)
=
convert_prompt_ids_to_tokens
(
tokenizer
=
tokenizer
,
prompt_ids
=
all_input_ids
[:
-
1
],
skip_special_tokens
=
prms
.
skip_special_tokens
,
)
(
new_tokens
,
new_decoded_token_text
,
prefix_offset
,
read_offset
)
=
detokenize_incrementally
(
tokenizer
=
tokenizer
,
all_input_ids
=
all_input_ids
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
)
# Decode logprobs
logprobs
=
seq
.
output_logprobs
[
-
1
]
if
logprobs
:
previous_tokens
=
all_input_ids
[:
-
1
]
for
token_id
,
sample_logprob
in
logprobs
.
items
():
# If the token was generated this iteration,
# use the provided text.
if
token_id
==
token_id_generated_this_iteration
:
sample_logprob
.
decoded_token
=
new_decoded_token_text
continue
if
(
sample_logprob
.
decoded_token
is
None
and
token_id
!=
VLLM_INVALID_TOKEN_ID
):
all_input_ids_with_logprob
=
previous_tokens
+
[
token_id
]
(
_
,
new_text
,
_
,
_
)
=
detokenize_incrementally
(
tokenizer
=
tokenizer
,
all_input_ids
=
all_input_ids_with_logprob
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
)
sample_logprob
.
decoded_token
=
new_text
seq
.
tokens
.
extend
(
new_tokens
)
seq
.
prefix_offset
=
prefix_offset
seq
.
read_offset
=
read_offset
seq
.
output_text
+=
new_decoded_token_text
return
len
(
new_decoded_token_text
)
\ No newline at end of file
vllm/zero_overhead/utils.py
deleted
100644 → 0
View file @
9bf1b213
from
enum
import
Enum
import
os
import
torch
import
vllm.envs
as
envs
zero_no_thread
=
os
.
environ
.
get
(
'VLLM_ZERO_NO_THREAD'
)
==
'1'
def
is_zero_no_thread
():
return
zero_no_thread
and
envs
.
VLLM_ZERO_OVERHEAD
class
SpecStepKind
(
Enum
):
KIND_DEFAULT
=
0
PREFILL
=
1
FIRST_PROPOSAL
=
2
OTHER_PROPOSAL
=
3
SCORE_DECODE
=
4
class
ZeroOverheadSpecContext
():
def
__init__
(
self
):
self
.
step_kind
=
SpecStepKind
.
KIND_DEFAULT
self
.
last_step
=
SpecStepKind
.
KIND_DEFAULT
self
.
proposal_lens_list
=
None
self
.
proposal_token_ids
=
None
self
.
accepted_token_ids
=
None
self
.
accepted_seq_ids
=
None
spec_context
=
ZeroOverheadSpecContext
()
def
set_spec_step
(
_step
):
global
spec_context
spec_context
.
last_step
=
spec_context
.
step_kind
spec_context
.
step_kind
=
_step
def
get_spec_step
():
return
spec_context
.
step_kind
def
get_spec_last_step
():
return
spec_context
.
last_step
def
record_proposal_lens_list
(
list
):
global
spec_context
spec_context
.
proposal_lens_list
=
list
def
get_proposal_lens_list
():
return
spec_context
.
proposal_lens_list
def
record_proposal_token_ids
(
tensor
):
global
spec_context
spec_context
.
proposal_token_ids
=
tensor
def
get_proposal_token_ids
():
return
spec_context
.
proposal_token_ids
def
record_accepted_token_ids
(
tensor
,
seq_ids
):
global
spec_context
spec_context
.
accepted_token_ids
=
tensor
spec_context
.
accepted_seq_ids
=
seq_ids
def
get_accepted_token_ids
():
return
spec_context
.
accepted_token_ids
,
spec_context
.
accepted_seq_ids
# 零消耗调度不在默认流上推理,用以规避runtime引入的内存申请流同步问题。
alloc_stream
=
{}
def
zero_overhead_stream
(
target_device
):
"""Asynchronously create a tensor and copy it from host to device."""
if
target_device
not
in
alloc_stream
.
keys
():
alloc_stream
[
target_device
]
=
torch
.
cuda
.
Stream
(
device
=
target_device
)
return
alloc_stream
[
target_device
]
vllm/zero_overhead/v1/core.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
from
collections
import
defaultdict
from
typing
import
Optional
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
requsets_valid_token_len
=
{}
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_valid_token_len
:
bool
=
False
)
->
bool
:
if
use_valid_token_len
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
else
:
valid_output_len
=
request
.
num_output_tokens
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
if
(
valid_num_tokens
>=
max_model_len
or
valid_output_len
>=
request
.
max_tokens
):
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
return
True
if
request
.
pooling_params
:
if
pooler_output
is
not
None
:
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
return
True
return
False
sampling_params
=
request
.
sampling_params
assert
sampling_params
is
not
None
last_token_id
=
request
.
output_token_ids
[
valid_output_len
-
1
]
if
(
not
sampling_params
.
ignore_eos
and
last_token_id
==
request
.
eos_token_id
):
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
return
True
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
request
.
stop_reason
=
last_token_id
return
True
return
False
def
zero_overhead_update_from_output
(
scheduler
:
Scheduler
,
scheduler_output
:
SchedulerOutput
,
model_runner_output
:
ZeroV1ModelRunnerOutput
):
global
requsets_valid_token_len
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
spec_token_ids
=
model_runner_output
.
spec_token_ids
logprobs
=
model_runner_output
.
logprobs
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
pooler_outputs
=
model_runner_output
.
pooler_output
num_nans_in_logits
=
model_runner_output
.
num_nans_in_logits
new_running
:
list
[
Request
]
=
[]
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# fix last model out in zero overhead
if
model_runner_output
.
fix_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
if
req_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
req_id
]
=
0
valid_output_len
=
requsets_valid_token_len
[
req_id
]
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
if
isinstance
(
generated_token_ids
,
int
):
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
generated_token_ids
=
[
generated_token_ids
]
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_idx
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_idx
,
req_idx
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
assert
not
prompt_logprobs_tensors
# fix last model out in zero overhead
if
model_runner_output
.
fix_draft_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_draft_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
# Add newly generated spec token ids to the request.
if
model_runner_output
.
fix_draft_tokens_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
])
else
:
request
.
spec_token_ids
=
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for
request
in
scheduler
.
running
:
req_id
=
request
.
request_id
if
request
.
is_finished
():
if
req_id
in
requsets_valid_token_len
:
requsets_valid_token_len
.
pop
(
req_id
)
continue
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
# The request was not scheduled in this step.
new_running
.
append
(
request
)
continue
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
sampled_token_ids
else
[]
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
if
scheduled_spec_token_ids
:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
request
.
num_computed_tokens
-=
num_tokens_rejected
spec_decoding_stats
=
scheduler
.
make_spec_decoding_stats
(
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_accepted_tokens
=
len
(
generated_token_ids
)
-
1
)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
if
request
.
has_encoder_inputs
:
scheduler
.
_free_encoder_inputs
(
request
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
request
.
append_output_token_ids
(
output_token_id
)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
if
model_runner_output
.
is_output_valid
:
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
if
model_runner_output
.
is_output_valid
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
spec_token_ids
[
req_index
])
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
if
model_runner_output
.
is_output_valid
:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
if
stopped
:
if
req_id
in
requsets_valid_token_len
:
requsets_valid_token_len
.
pop
(
req_id
)
else
:
new_running
.
append
(
request
)
scheduler
.
running
=
new_running
# KV Connector: update state for finished KV Transfers.
scheduler
.
_update_from_kv_xfer_finished
(
model_runner_output
)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs
=
{
client_index
:
EngineCoreOutputs
(
outputs
=
outs
)
for
client_index
,
outs
in
outputs
.
items
()
}
finished_req_ids
=
scheduler
.
finished_req_ids_dict
if
finished_req_ids
:
# Include ids of requests that finished since last outputs
# were sent.
for
client_index
,
finished_set
in
finished_req_ids
.
items
():
# Set finished request set in EngineCoreOutputs for this client.
if
(
eco
:
=
engine_core_outputs
.
get
(
client_index
))
is
not
None
:
eco
.
finished_requests
=
finished_set
else
:
engine_core_outputs
[
client_index
]
=
EngineCoreOutputs
(
finished_requests
=
finished_set
)
finished_req_ids
.
clear
()
if
engine_core_outputs
:
# Return stats to only one of the front-ends.
next
(
iter
(
engine_core_outputs
.
values
())).
scheduler_stats
=
(
scheduler
.
make_stats
(
spec_decoding_stats
))
return
engine_core_outputs
def
engine_core_step
(
core
)
->
tuple
[
dict
[
int
,
EngineCoreOutputs
],
bool
]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if
not
core
.
scheduler
.
has_requests
():
return
{},
False
scheduler_output
=
core
.
scheduler
.
schedule
()
model_output
=
core
.
execute_model
(
scheduler_output
)
if
isinstance
(
model_output
,
ZeroV1ModelRunnerOutput
):
engine_core_outputs
=
zero_overhead_update_from_output
(
core
.
scheduler
,
scheduler_output
,
model_output
)
# type: ignore
else
:
engine_core_outputs
=
core
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# type: ignore
return
(
engine_core_outputs
,
scheduler_output
.
total_num_scheduled_tokens
>
0
)
\ No newline at end of file
vllm/zero_overhead/v1/eagle.py
deleted
100644 → 0
View file @
9bf1b213
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
class
V1ZeroEagleProposer
(
EagleProposer
):
def
__init__
(
self
,
vllm_config
,
device
,
runner
=
None
):
super
().
__init__
(
vllm_config
,
device
,
runner
)
self
.
spec_scheduler_max_num_tokens
=
0
def
propose
(
self
,
# [num_tokens]
target_token_ids
:
torch
.
Tensor
,
# [num_tokens]
target_positions
:
torch
.
Tensor
,
# [num_tokens, hidden_size]
target_hidden_states
:
torch
.
Tensor
,
# [num_tokens]
target_slot_mapping
:
torch
.
Tensor
,
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
# [batch_size + 1] starting with 0
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
# [batch_size]
sampling_metadata
:
SamplingMetadata
,
decoding
:
bool
=
False
,
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_num_tokens
,
query_start_loc
=
cu_num_tokens
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
slot_mapping
=
target_slot_mapping
,
# TODO(woosuk): Support cascade attention.
use_cascade
=
False
,
common_prefix_len
=
0
,
cu_prefix_query_lens
=
None
,
prefix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
)
elif
self
.
method
==
"deepseek_mtp"
:
max_query_len
=
self
.
spec_scheduler_max_num_tokens
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
slot_mapping
=
target_slot_mapping
,
spec_layer_decoding
=
decoding
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
decoding
and
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
hidden_states
[:
num_input_tokens
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
positions
=
target_positions
[
last_token_indices
]
if
self
.
method
==
"deepseek_mtp"
:
hidden_states
=
last_hidden_states
[
last_token_indices
]
else
:
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
input_batch_size
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
num_decodes
=
batch_size
attn_metadata
.
num_decode_tokens
=
batch_size
attn_metadata
.
num_prefills
=
0
block_table
=
self
.
runner
.
attn_metadata_builders
[
0
].
block_table
.
get_device_tensor
()[:
batch_size
,
...]
attn_metadata
.
decode
=
self
.
runner
.
attn_metadata_builders
[
0
].
_build_decode
(
block_table_tensor
=
block_table
,
seq_lens
=
seq_lens
,
)
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
positions
+=
1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len
=
positions
>=
self
.
max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
decode
.
seq_lens
+=
1
else
:
attn_metadata
.
seq_lens
+=
1
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
input_batch_size
],
self
.
positions
[:
input_batch_size
],
self
.
hidden_states
[:
input_batch_size
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
[:
batch_size
]
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment