Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cfabf125
"vllm/vscode:/vscode.git/clone" did not exist on "4a30d7e3ccae6e977d728e2157aaa11ac0fed549"
Commit
cfabf125
authored
Aug 27, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
parents
dbd0bda6
645fcfd9
Changes
28
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
438 additions
and
1632 deletions
+438
-1632
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+6
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+0
-1544
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+2
-1
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+38
-15
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+4
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+317
-0
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+70
-62
No files found.
vllm/compilation/decorators.py
View file @
cfabf125
...
@@ -9,9 +9,10 @@ import torch
...
@@ -9,9 +9,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
vllm
import
envs
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.forward_context
import
get_profilling
from
vllm.forward_context
import
get_forward_context
,
get_profilling
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -170,6 +171,10 @@ def _support_torch_compile(
...
@@ -170,6 +171,10 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
skip_cuda_graphs
=
get_forward_context
().
skip_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
skip_cuda_graphs
:
return
self
.
forward
(
*
args
,
**
kwargs
)
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
...
...
vllm/model_executor/model_loader/loader.py
deleted
100644 → 0
View file @
dbd0bda6
This diff is collapsed.
Click to expand it.
vllm/two_batch_overlap/two_batch_overlap.py
View file @
cfabf125
...
@@ -58,7 +58,8 @@ class TwoBatchOverlap():
...
@@ -58,7 +58,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
.
join
()
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
cfabf125
...
@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
...
@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadataBuilder
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
...
@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
kv_cache_group_spec
.
kv_cache_spec
,
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
,
metadata_builder
,
)
)
if
req_offset
>
0
:
if
req_offset
>
0
:
origin_block_table
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
origin_block_table
=
metadata_builder
.
block_table
.
block_table
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
=
\
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
attn_metadata_i
=
(
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
build
(
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
if
req_offset
>
0
:
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_prefill_tokens
=
_num_prefill_tokens_record
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
...
@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
use_tbo
=
False
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
use_tbo
:
if
use_tbo
:
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
...
@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
...
@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
model_output
=
runner
.
model
(
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
cfabf125
...
@@ -50,7 +50,8 @@ class TwoBatchOverlap():
...
@@ -50,7 +50,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
.
join
()
...
@@ -71,7 +72,6 @@ class TwoBatchOverlap():
...
@@ -71,7 +72,6 @@ class TwoBatchOverlap():
init_tbo_forward_context
(
False
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
queue
.
get
()
queue
.
get
()
profile
.
ProfRangePush
(
'start'
)
self
.
tbo_thread_synchronize
(
tid
)
self
.
tbo_thread_synchronize
(
tid
)
if
is_left_thread
:
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
attn_metadata
=
self
.
attn_metadata_left
...
@@ -90,7 +90,8 @@ class TwoBatchOverlap():
...
@@ -90,7 +90,8 @@ class TwoBatchOverlap():
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
):
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
model_output
=
self
.
model_runner
.
model
(
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
@@ -102,22 +103,17 @@ class TwoBatchOverlap():
...
@@ -102,22 +103,17 @@ class TwoBatchOverlap():
self
.
states_left_queue
.
put
(
model_output
)
self
.
states_left_queue
.
put
(
model_output
)
else
:
else
:
self
.
states_right_queue
.
put
(
model_output
)
self
.
states_right_queue
.
put
(
model_output
)
profile
.
ProfRangePop
()
def
tbo_thread_synchronize
(
self
,
tid
):
def
tbo_thread_synchronize
(
self
,
tid
):
if
tid
==
self
.
left_tid
:
if
tid
==
self
.
left_tid
:
if
not
self
.
left_first
:
if
not
self
.
left_first
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
self
.
left_first
=
False
self
.
left_first
=
False
profile
.
ProfRangePop
()
self
.
sem_left
.
acquire
()
self
.
sem_left
.
acquire
()
profile
.
ProfRangePush
(
'left'
)
return
self
.
event_left_c2t
,
self
.
event_left_t2c
return
self
.
event_left_c2t
,
self
.
event_left_t2c
else
:
else
:
self
.
sem_left
.
release
()
self
.
sem_left
.
release
()
profile
.
ProfRangePop
()
self
.
sem_right
.
acquire
()
self
.
sem_right
.
acquire
()
profile
.
ProfRangePush
(
'right'
)
return
self
.
event_right_c2t
,
self
.
event_right_t2c
return
self
.
event_right_c2t
,
self
.
event_right_t2c
def
set_model_input
(
self
,
def
set_model_input
(
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cfabf125
...
@@ -1373,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1373,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
)
:
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
...
vllm/zero_overhead/v1/eagle.py
0 → 100644
View file @
cfabf125
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
class
V1ZeroEagleProposer
(
EagleProposer
):
def
__init__
(
self
,
vllm_config
,
device
,
runner
=
None
):
super
().
__init__
(
vllm_config
,
device
,
runner
)
self
.
spec_scheduler_max_num_tokens
=
0
def
propose
(
self
,
# [num_tokens]
target_token_ids
:
torch
.
Tensor
,
# [num_tokens]
target_positions
:
torch
.
Tensor
,
# [num_tokens, hidden_size]
target_hidden_states
:
torch
.
Tensor
,
# [num_tokens]
target_slot_mapping
:
torch
.
Tensor
,
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
# [batch_size + 1] starting with 0
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
# [batch_size]
sampling_metadata
:
SamplingMetadata
,
decoding
:
bool
=
False
,
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_num_tokens
,
query_start_loc
=
cu_num_tokens
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
slot_mapping
=
target_slot_mapping
,
# TODO(woosuk): Support cascade attention.
use_cascade
=
False
,
common_prefix_len
=
0
,
cu_prefix_query_lens
=
None
,
prefix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
)
elif
self
.
method
==
"deepseek_mtp"
:
max_query_len
=
self
.
spec_scheduler_max_num_tokens
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
slot_mapping
=
target_slot_mapping
,
spec_layer_decoding
=
decoding
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
decoding
and
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
hidden_states
[:
num_input_tokens
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
positions
=
target_positions
[
last_token_indices
]
if
self
.
method
==
"deepseek_mtp"
:
hidden_states
=
last_hidden_states
[
last_token_indices
]
else
:
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
input_batch_size
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
num_decodes
=
batch_size
attn_metadata
.
num_decode_tokens
=
batch_size
attn_metadata
.
num_prefills
=
0
block_table
=
self
.
runner
.
attn_metadata_builders
[
0
].
block_table
.
get_device_tensor
()[:
batch_size
,
...]
attn_metadata
.
decode
=
self
.
runner
.
attn_metadata_builders
[
0
].
_build_decode
(
block_table_tensor
=
block_table
,
seq_lens
=
seq_lens
,
)
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
positions
+=
1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len
=
positions
>=
self
.
max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
decode
.
seq_lens
+=
1
else
:
attn_metadata
.
seq_lens
+=
1
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
input_batch_size
],
self
.
positions
[:
input_batch_size
],
self
.
hidden_states
[:
input_batch_size
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
[:
batch_size
]
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
cfabf125
...
@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...
@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.zero_overhead.v1.eagle
import
V1ZeroEagleProposer
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
...
@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_host_tokens
=
None
self
.
token_ids_cpu_fix_recod
e
=
[]
self
.
token_ids_cpu_fix_reco
r
d
=
[]
self
.
last_draft_token_ids
=
None
self
.
last_draft_token_ids
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_scheduler_max_num_tokens
=
0
if
hasattr
(
self
,
'drafter'
)
and
isinstance
(
self
.
drafter
,
EagleProposer
):
self
.
drafter
=
V1ZeroEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
...
@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
max_num_scheduled_tokens
=
max
(
tokens
)
self
.
spec_scheduler_max_num_tokens
=
max_num_scheduled_tokens
# Get request indices.
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
sampled_token_ids
:
list
[
list
[
int
]],
num_accepted_tokens_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
...
@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
# TODO(woosuk): Refactor the loop.
if
self
.
last_sampled_token_ids
is
not
None
:
row_indices
=
torch
.
arange
(
sampled_token_ids
.
size
(
0
),
device
=
sampled_token_ids
.
device
)
next_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
next_token_ids
=
sampled_token_ids
[
row_indices
,
num_accepted_tokens_tensor
].
flatten
()
else
:
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# At this moment, we assume all eagle layers belong to the same KV
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
eagle_attn_metadata
=
attn_metadata
[
...
@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
else
:
else
:
block_table
=
None
block_table
=
None
spec_scheduler_max_num_tokens
=
self
.
spec_scheduler_max_num_tokens
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
...
@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
else
:
# TODO(woosuk): Refactor this.
# TODO(woosuk): Refactor this.
num_accepted_tokens
=
[
len
(
s
)
-
1
for
s
in
sampled_token_ids
]
num_accepted_tokens_tensor
=
async_tensor_h2d
(
num_accepted_tokens
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
eagle_attn_metadata
.
query_start_loc
,
num_accepted_tokens_tensor
,
num_accepted_tokens_tensor
,
)
)
spec_scheduler_max_num_tokens
=
1
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
target_positions
=
self
.
positions
[
token_indices
]
...
@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_hidden_states
=
hidden_states
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
token_indices
]
self
.
drafter
.
spec_scheduler_max_num_tokens
=
spec_scheduler_max_num_tokens
draft_token_ids
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
...
@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
=
cu_num_tokens
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
decoding
=
spec_decode_metadata
is
not
None
decoding
=
spec_decode_metadata
is
not
None
,
)
)
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
self
.
last_draft_token_ids
=
draft_token_ids
self
.
last_draft_token_ids
=
draft_token_ids
...
@@ -486,7 +473,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -486,7 +473,7 @@ class V1ZeroModelRunner(GPUModelRunner):
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
)
:
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output
,
scheduler_output
,
)
)
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
fix_req_ids
=
None
fix_req_ids
=
None
fix_sampled_token_ids
=
None
fix_sampled_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_req_ids
=
self
.
last_sampled_req_ids
fix_draft_req_ids
=
self
.
last_sampled_req_ids
is_output_valid
=
False
is_output_valid
=
False
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
spec_sampler_event
.
record
()
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
spec_sampler_event
.
synchronize
()
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
_cpu
.
tolist
()
else
:
else
:
# Includes spec decode tokens.
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
sampled_token_ids
_cpu
,
self
.
input_batch
.
vocab_size
,
self
.
input_batch
.
vocab_size
,
)
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_host_tokens
=
None
...
@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
if
self
.
last_sampler_host_tokens
!=
None
:
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
self
.
last_sampler_event
.
synchronize
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_recode
:
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
if
start_idx
==
-
1
:
continue
req_id
=
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
if
req_id
in
self
.
requests
:
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
req_state
.
output_token_ids
[
token_idx
]
=
fix_sampled_token_ids
[
req_idx
][
0
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
self
.
last_sampled_token_ids
=
sampled_token_ids
...
@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
# between the first-stage worker and the last-stage worker.
self
.
token_ids_cpu_fix_recod
e
.
clear
()
self
.
token_ids_cpu_fix_reco
r
d
.
clear
()
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
-
1
if
not
sampled_ids
:
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
continue
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
...
@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
start_idx
:
end_idx
]
=
sampled_ids
self
.
token_ids_cpu_fix_recod
e
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
token_ids_cpu_fix_reco
r
d
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
if
req_id
in
self
.
requests
:
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
len
(
req_state
.
output_token_ids
)
self
.
last_sampled_token_lens
.
append
(
len
(
req_state
.
output_token_ids
))
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment