Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22a95571
Commit
22a95571
authored
Aug 18, 2025
by
zhuwenwen
Browse files
add v1 engine + deepseek r1 mtp + zero-overhead scheduler
parent
ac4cc84e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
727 additions
and
286 deletions
+727
-286
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+15
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-28
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+7
-2
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+148
-111
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+533
-140
vllm/zero_overhead/v1/outputs.py
vllm/zero_overhead/v1/outputs.py
+4
-1
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
22a95571
...
...
@@ -764,10 +764,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats
=
torch
.
from_numpy
(
query_lens
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
).
contiguous
()
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
if
envs
.
VLLM_ZERO_OVERHEAD
:
decode_block_table_tensor
=
torch
.
empty
((
self
.
_num_decode_tokens
,
block_table_tensor
.
shape
[
1
]),
device
=
block_table_tensor
.
device
)
arange_np
=
np
.
arange
(
self
.
_num_decodes
)
indices_np
=
np
.
repeat
(
arange_np
,
query_lens
)
indices
=
torch
.
from_numpy
(
indices_np
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
)
decode_block_table_tensor
=
block_table_tensor
[
indices
].
contiguous
()
decode_seq_lens
=
seq_lens
[
indices
].
contiguous
()
else
:
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
self
.
_num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
seq_lens_minus
=
torch
.
from_numpy
(
rarange
).
to
(
torch
.
int32
).
pin_memory
().
to
(
seq_lens
.
device
,
non_blocking
=
True
).
contiguous
()
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
22a95571
...
...
@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.zero_overhead.v1.gpu_model_runner
import
execute_model_sampled
,
zero_prepare_inputs
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
bind_kv_cache
,
gather_mm_placeholders
,
...
...
@@ -1020,15 +1019,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
if
envs
.
VLLM_ZERO_OVERHEAD
:
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
else
:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...
...
@@ -1440,9 +1450,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ZERO_OVERHEAD
:
zero_prepare_inputs
(
self
,
scheduler_output
,
input_ids
)
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
model_output
,
finished_sending
,
finished_recving
=
\
...
...
@@ -1591,21 +1598,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
envs
.
VLLM_ZERO_OVERHEAD
:
return
execute_model_sampled
(
self
,
max_gen_len
,
sampled_token_ids
,
discard_sampled_tokens_req_indices
,
scheduler_output
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
logprobs_lists
,
prompt_logprobs_dict
,
finished_sending
,
finished_recving
,
num_nans_in_logits
)
if
max_gen_len
==
1
:
# No spec decode tokens.
...
...
vllm/v1/worker/gpu_worker.py
View file @
22a95571
...
...
@@ -33,6 +33,7 @@ from vllm.v1.utils import report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
...
...
@@ -187,8 +188,12 @@ class Worker(WorkerBase):
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
...
...
vllm/zero_overhead/v1/core.py
View file @
22a95571
...
...
@@ -12,11 +12,15 @@ requsets_valid_token_len = {}
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_valid_token_len
:
bool
=
False
)
->
bool
:
if
use_valid_token_len
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
else
:
valid_output_len
=
request
.
num_output_tokens
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
if
(
valid_num_tokens
>=
max_model_len
or
valid_output_len
>=
request
.
max_tokens
):
...
...
@@ -60,110 +64,118 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# fix last model out in zero overhead
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
if
req_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
req_id
]
=
0
valid_output_len
=
requsets_valid_token_len
[
req_id
]
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
if
isinstance
(
generated_token_ids
,
int
):
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
if
model_runner_output
.
fix_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
if
req_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
req_id
]
=
0
valid_output_len
=
requsets_valid_token_len
[
req_id
]
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
if
isinstance
(
generated_token_ids
,
int
):
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
spec_token_ids
[
req_index
])
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
assert
not
prompt_logprobs_tensors
# fix last model out in zero overhead
if
model_runner_output
.
fix_draft_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_draft_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
# Add newly generated spec token ids to the request.
if
model_runner_output
.
fix_draft_tokens_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
])
else
:
request
.
spec_token_ids
=
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for
request
in
scheduler
.
running
:
if
request
.
is_finished
():
continue
req_id
=
request
.
request_id
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
...
...
@@ -197,7 +209,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if
request
.
has_encoder_inputs
:
scheduler
.
_free_encoder_inputs
(
request
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
...
...
@@ -210,19 +221,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
# del new_token_ids[num_new:] # Trim new tokens if needed.
# break
if
model_runner_output
.
is_output_valid
:
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
if
model_runner_output
.
is_output_valid
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
...
...
@@ -252,6 +268,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_token_ids
[
req_index
])
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
if
model_runner_output
.
is_output_valid
:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
if
not
stopped
:
new_running
.
append
(
request
)
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
22a95571
from
typing
import
Any
,
Optional
,
Union
import
torch
import
numpy
as
np
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.utils
import
async_tensor_h2d
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
async_tensor_h2d
,
round_up
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.profiler.prof
import
profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
class
V1ZeroModelRunner
():
def
__init__
(
self
):
class
V1ZeroModelRunner
(
GPUModelRunner
):
def
__init__
(
self
,
vllm_config
,
device
):
super
().
__init__
(
vllm_config
,
device
)
self
.
last_sampled_token_ids
=
None
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_host_tokens
=
None
self
.
token_ids_cpu_fix_recode
=
[]
self
.
last_draft_token_ids
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
set_last_sampled_token_ids
(
self
,
sampled_token_ids
):
self
.
last_sampled_token_ids
=
sampled_token_ids
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
v1_zero_overhead
=
V1ZeroModelRunner
()
def
zero_prepare_inputs
(
runner
,
scheduler_output
,
input_ids
):
req_ids
=
runner
.
input_batch
.
req_ids
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
v1_zero_overhead
.
last_sampled_token_ids
is
None
:
return
sampled_tokens_num
=
v1_zero_overhead
.
last_sampled_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
v1_zero_overhead
.
last_sampled_req_ids
:
req_idx
=
v1_zero_overhead
.
last_sampled_req_ids
.
index
(
req_id
)
*
sampled_tokens_num
update_req_indices
.
append
(
req_idx
)
input_ids_indices
.
append
(
token_idx
)
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
if
len
(
update_req_indices
)
>
0
:
update_req_indices_tensor
=
async_tensor_h2d
(
update_req_indices
,
torch
.
int32
,
runner
.
device
,
True
)
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
runner
.
device
,
True
)
last_sampled_token_ids
=
v1_zero_overhead
.
last_sampled_token_ids
.
flatten
()
for
i
in
range
(
sampled_tokens_num
):
input_ids
[
input_ids_indices_tensor
+
i
]
=
last_sampled_token_ids
[
update_req_indices_tensor
+
i
]
def
execute_model_sampled
(
runner
,
max_gen_len
,
sampled_token_ids
,
discard_sampled_tokens_req_indices
,
scheduler_output
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
logprobs_lists
,
prompt_logprobs_dict
,
finished_sending
,
finished_recving
,
num_nans_in_logits
):
fix_req_ids
=
None
fix_sampled_token_ids
=
None
if
max_gen_len
==
1
:
# No spec decode tokens.
if
v1_zero_overhead
.
last_sampler_host_tokens
!=
None
:
v1_zero_overhead
.
last_sampler_event
.
synchronize
()
fix_sampled_token_ids
=
v1_zero_overhead
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
v1_zero_overhead
.
token_ids_cpu_fix_recode
:
runner
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
fix_req_ids
=
v1_zero_overhead
.
last_sampled_req_ids
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
if
req_id
in
runner
.
requests
:
req_state
=
runner
.
requests
[
req_id
]
token_idx
=
v1_zero_overhead
.
last_sampled_token_lens
[
req_idx
]
req_state
.
output_token_ids
[
token_idx
]
=
fix_sampled_token_ids
[
req_idx
][
0
]
v1_zero_overhead
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
v1_zero_overhead
.
last_sampler_event
.
record
()
v1_zero_overhead
.
set_last_sampled_token_ids
(
sampled_token_ids
)
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
runner
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
runner
.
input_batch
.
vocab_size
,
)
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
v1_zero_overhead
.
token_ids_cpu_fix_recode
.
clear
()
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
if
not
sampled_ids
:
continue
start_idx
=
runner
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
runner
.
max_model_len
,
(
"Sampled token IDs exceed the max model length. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len: "
f
"
{
runner
.
max_model_len
}
"
)
runner
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
v1_zero_overhead
.
token_ids_cpu_fix_recode
.
append
([
req_idx
,
start_idx
,
end_idx
])
runner
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
runner
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
req_id
=
runner
.
input_batch
.
req_ids
[
req_idx
]
if
req_id
in
runner
.
requests
:
req_state
=
runner
.
requests
[
req_id
]
v1_zero_overhead
.
last_sampled_req_ids
.
append
(
req_id
)
v1_zero_overhead
.
last_sampled_token_lens
.
append
(
len
(
req_state
.
output_token_ids
))
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
if
not
runner
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
def
zero_prepare_inputs
(
self
,
scheduler_output
,
input_ids
):
req_ids
=
self
.
input_batch
.
req_ids
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
self
.
last_draft_token_ids
is
not
None
:
draft_tokens_num
=
self
.
last_draft_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
draft_tokens_num
for
num_idx
in
range
(
draft_tokens_num
):
update_req_indices
.
append
(
req_idx
+
num_idx
)
input_ids_indices
.
append
(
token_idx
+
num_idx
+
1
)
token_idx
+=
draft_tokens_num
+
1
if
len
(
update_req_indices
)
>
0
:
update_req_indices_tensor
=
async_tensor_h2d
(
update_req_indices
,
torch
.
int32
,
self
.
device
,
True
)
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
else
:
spec_token_ids
=
runner
.
propose_draft_token_ids
(
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
self
.
last_sampled_token_ids
is
not
None
:
sampled_tokens_num
=
self
.
last_sampled_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
sampled_tokens_num
update_req_indices
.
append
(
req_idx
)
input_ids_indices
.
append
(
token_idx
)
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
if
len
(
update_req_indices
)
>
0
:
update_req_indices_tensor
=
async_tensor_h2d
(
update_req_indices
,
torch
.
int32
,
self
.
device
,
True
)
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
True
)
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
for
i
in
range
(
sampled_tokens_num
):
input_ids
[
input_ids_indices_tensor
+
i
]
=
last_sampled_token_ids
[
update_req_indices_tensor
+
i
]
def
propose_draft_token_ids
(
self
,
scheduler_output
:
"SchedulerOutput"
,
sampled_token_ids
:
list
[
list
[
int
]],
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
attn_metadata
:
dict
[
str
,
Any
],
)
->
list
[
list
[
int
]]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
spec_token_ids
=
self
.
propose_ngram_draft_token_ids
(
sampled_token_ids
)
elif
self
.
speculative_config
.
method
==
"medusa"
:
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
if
sample_hidden_states
.
shape
[
0
]
==
len
(
sampled_token_ids
):
# The input to the target model does not include draft tokens.
hidden_states
=
sample_hidden_states
else
:
indices
=
[]
offset
=
0
for
num_draft
,
tokens
in
zip
(
spec_decode_metadata
.
num_draft_tokens
,
sampled_token_ids
):
indices
.
append
(
offset
+
len
(
tokens
)
-
1
)
offset
+=
num_draft
+
1
indices
=
torch
.
tensor
(
indices
,
device
=
self
.
device
)
hidden_states
=
sample_hidden_states
[
indices
]
spec_token_ids
=
self
.
drafter
.
propose
(
target_hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
)
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
if
self
.
last_sampled_token_ids
is
not
None
:
next_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
else
:
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_names
[
0
]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if
hasattr
(
eagle_attn_metadata
,
"block_table"
):
block_table
=
eagle_attn_metadata
.
block_table
else
:
block_table
=
None
num_rejected_tokens
=
None
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[:
num_scheduled_tokens
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_rejected_tokens
=
[
n
+
1
-
len
(
sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens_tensor
=
async_tensor_h2d
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens_tensor
,
num_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens
=
num_rejected_tokens
)
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
self
.
last_draft_token_ids
=
draft_token_ids
self
.
last_draft_host_tokens
=
draft_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_draft_event
.
record
()
return
spec_token_ids
@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
    """Run one model step and package the result for the scheduler.

    When speculative decoding is disabled, the sampled token ids of this
    step are copied to the host with a non-blocking transfer and the
    returned ``sampled_token_ids`` only contain placeholder values (1).
    The real ids of the *previous* step are read back here (after waiting
    on ``last_sampler_event``), patched into the cached request state and
    reported through ``fix_req_ids`` / ``fix_sampled_token_ids``.

    When speculative decoding is enabled, the sampled ids are materialised
    immediately (``is_output_valid=True``) and only the draft tokens are
    deferred, reported through ``fix_draft_req_ids`` /
    ``fix_draft_tokens_ids``.

    Args:
        scheduler_output: The scheduling decision for this step.
        intermediate_tensors: Tensors received from the previous pipeline
            stage; replaced by ``None`` on the first pipeline rank.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` or the result of
        ``kv_connector_no_forward`` when nothing is scheduled; the hidden
        states on a non-last pipeline rank that does not broadcast; the
        result of ``_pool`` when the batch has pooling params; otherwise a
        ``ZeroV1ModelRunnerOutput``.
    """
    self._update_states(scheduler_output)
    if not scheduler_output.total_num_scheduled_tokens:
        if not has_kv_transfer_group():
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        return self.kv_connector_no_forward(scheduler_output)

    # Prepare the decoder inputs.
    (attn_metadata, attention_cuda_graphs, logits_indices,
     spec_decode_metadata,
     num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    if (self.use_cuda_graph
            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
        # Use piecewise CUDA graphs.
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(
            num_scheduled_tokens)
    else:
        # Eager mode.
        # Pad tokens to multiple of tensor_parallel_size when
        # enabled collective fusion for SP
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        if self.compilation_config.pass_config. \
            enable_sequence_parallelism and tp_size > 1:
            num_input_tokens = round_up(num_scheduled_tokens, tp_size)
        else:
            num_input_tokens = num_scheduled_tokens

    # Padding for DP
    num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
    num_input_tokens += num_pad

    # _prepare_inputs may reorder the batch, so we must gather multi
    # modal outputs after that to ensure the correct order
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_mm_encoder(scheduler_output)
        mm_embeds = self._gather_mm_embeddings(scheduler_output)
    else:
        mm_embeds = []

    if self.is_multimodal_model and get_pp_group().is_first_rank:
        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        input_ids = self.input_ids[:num_scheduled_tokens]
        if mm_embeds:
            inputs_embeds = self.model.get_input_embeddings(
                input_ids, mm_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        input_ids = None
    else:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the CUDA graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
    if self.uses_mrope:
        positions = self.mrope_positions[:, :num_input_tokens]
    else:
        positions = self.positions[:num_input_tokens]

    if get_pp_group().is_first_rank:
        intermediate_tensors = None
    else:
        intermediate_tensors = self.sync_and_slice_intermediate_tensors(
            num_input_tokens, intermediate_tensors, True)

    # Some attention backends only support CUDA Graphs in pure decode.
    # If attention doesn't support CUDA Graphs for this batch, but we
    # compiled with full CUDA graphs, we have to skip them entirely.
    skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

    # Zero-overhead specific input preparation. Note that `input_ids` is
    # None here on the multimodal first-rank path above.
    self.zero_prepare_inputs(scheduler_output, input_ids)

    if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
        model_output, finished_sending, finished_recving = \
            tbo_split_and_execute_model(self, attn_metadata,
                                        num_input_tokens,
                                        num_tokens_across_dp, input_ids,
                                        positions, inputs_embeds,
                                        scheduler_output,
                                        intermediate_tensors)
    else:
        # Run the model.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs,
        ):
            self.maybe_setup_kv_connector(scheduler_output)

            model_output = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )

            self.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

    if self.use_aux_hidden_state_outputs:
        hidden_states, aux_hidden_states = model_output
    else:
        hidden_states = model_output
        aux_hidden_states = None

    # Broadcast PP output for external_launcher (torchrun)
    # to make sure we are synced across pp ranks
    # TODO: Support overlapping micro-batches
    # https://github.com/vllm-project/vllm/issues/18019
    broadcast_pp_output = \
        self.parallel_config.distributed_executor_backend \
        == "external_launcher" and len(get_pp_group().ranks) > 0
    if not get_pp_group().is_last_rank:
        # For mid-pipeline stages, return the hidden states.
        if not broadcast_pp_output:
            return hidden_states
        assert isinstance(hidden_states, IntermediateTensors)
        get_pp_group().send_tensor_dict(hidden_states.tensors,
                                        all_gather_group=get_tp_group())
        logits = None
    else:
        if self.input_batch.pooling_params:
            return self._pool(hidden_states, num_scheduled_tokens,
                              num_scheduled_tokens_np, finished_sending,
                              finished_recving)

        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
    if broadcast_pp_output:
        model_output_broadcast_data = {
            "logits": logits.contiguous(),
        } if logits is not None else {}
        model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
            model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
        assert model_output_broadcast_data is not None
        logits = model_output_broadcast_data["logits"]

    # Apply structured output bitmasks if present
    if scheduler_output.grammar_bitmask is not None:
        self.apply_grammar_bitmask(scheduler_output, logits)

    # Sample the next token and get logprobs if needed.
    sampling_metadata = self.input_batch.sampling_metadata
    if spec_decode_metadata is None:
        sampler_output = self.sampler(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
    else:
        # When indexing with a tensor (bonus_logits_indices), PyTorch
        # creates a new tensor with separate storage from the original
        # logits tensor. This means any in-place operations on bonus_logits
        # won't affect the original logits tensor.
        assert logits is not None
        bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
        sampler_output = self.sampler(
            logits=bonus_logits,
            sampling_metadata=sampling_metadata,
        )
        bonus_token_ids = sampler_output.sampled_token_ids

        # Just like `bonus_logits`, `target_logits` is a new tensor with
        # separate storage from the original `logits` tensor. Therefore,
        # it is safe to update `target_logits` in place.
        target_logits = logits[spec_decode_metadata.target_logits_indices]
        output_token_ids = self.rejection_sampler(
            spec_decode_metadata,
            None,  # draft_probs
            target_logits,
            bonus_token_ids,
            sampling_metadata,
        )
        sampler_output.sampled_token_ids = output_token_ids

    num_nans_in_logits = {}
    if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
        num_nans_in_logits = self._get_nans_in_logits(logits)

    # TODO(woosuk): The following loop can be slow since it iterates over
    # the requests one by one. Optimize.
    discard_sampled_tokens_req_indices = []
    for i, req_id in enumerate(self.input_batch.req_ids):
        req_state = self.requests[req_id]
        seq_len = (req_state.num_computed_tokens +
                   scheduler_output.num_scheduled_tokens[req_id])
        if seq_len < req_state.num_tokens:
            # Ignore the sampled token for partial prefills.
            # Rewind the generator state as if the token was not sampled.
            # This relies on cuda-specific torch-internal impl details
            generator = self.input_batch.generators.get(i)
            if generator is not None:
                generator.set_offset(generator.get_offset() - 4)
            # Record the index of the request that should not be sampled,
            # so that we could clear the sampled tokens before returning.
            discard_sampled_tokens_req_indices.append(i)

    # NOTE: GPU -> CPU Sync happens here.
    # Move as many CPU operations as possible before this sync point.
    logprobs_tensors = sampler_output.logprobs_tensors
    logprobs_lists = logprobs_tensors.tolists() \
        if logprobs_tensors is not None else None

    # Compute prompt logprobs if needed.
    prompt_logprobs_dict = self._get_prompt_logprobs_dict(
        hidden_states[:num_scheduled_tokens],
        scheduler_output,
    )

    # Get the valid generated tokens.
    sampled_token_ids = sampler_output.sampled_token_ids
    max_gen_len = sampled_token_ids.shape[-1]
    fix_req_ids = None
    fix_sampled_token_ids = None
    fix_draft_token_ids = None
    # Keep a reference to the previous step's request-id list; the
    # attribute is rebound (not cleared in place) further below, so this
    # reference stays intact.
    fix_draft_req_ids = self.last_sampled_req_ids
    is_output_valid = False
    if self.speculative_config:
        if max_gen_len == 1:
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # The ids above are real values, so nothing is deferred.
        self.last_sampler_host_tokens = None
        self.last_sampled_token_ids = None
        is_output_valid = True
    else:
        # No spec decode tokens.
        fix_req_ids = self.last_sampled_req_ids
        if self.last_sampler_host_tokens is not None:
            # Wait for the previous step's async device-to-host copy.
            self.last_sampler_event.synchronize()
            fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
            # Overwrite the placeholders written during the previous step.
            for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_recode:
                self.input_batch.token_ids_cpu[
                    req_idx,
                    start_idx:end_idx] = fix_sampled_token_ids[req_idx]
            # NOTE(review): `req_idx` below is the position inside
            # `fix_req_ids`, while `fix_sampled_token_ids` has one row per
            # request of the previous batch. These only line up if no
            # request was skipped when `last_sampled_req_ids` was
            # recorded -- confirm.
            for req_idx, req_id in enumerate(fix_req_ids):
                if req_id in self.requests:
                    req_state = self.requests[req_id]
                    token_idx = self.last_sampled_token_lens[req_idx]
                    req_state.output_token_ids[
                        token_idx] = fix_sampled_token_ids[req_idx][0]
        # Start the async copy of this step's ids; they are consumed on the
        # next call.
        self.last_sampler_host_tokens = sampled_token_ids.to(
            'cpu', non_blocking=True)
        self.last_sampler_event.record()
        self.last_sampled_token_ids = sampled_token_ids
        # Placeholder ids (all ones) with the same shape as the real ones.
        valid_sampled_token_ids = np.ones(sampled_token_ids.shape,
                                          dtype=int).tolist()

    # Mask out the sampled tokens that should not be sampled.
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()

    # Cache the sampled tokens in the model runner, so that the scheduler
    # doesn't need to send them back.
    # NOTE(woosuk): As an exception, when using PP, the scheduler sends
    # the sampled tokens back, because there's no direct communication
    # between the first-stage worker and the last-stage worker.
    self.token_ids_cpu_fix_recode.clear()
    # Rebind instead of clearing: the old lists are still referenced by
    # `fix_req_ids` / `fix_draft_req_ids`.
    self.last_sampled_req_ids = []
    self.last_sampled_token_lens = []
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if not sampled_ids:
            continue

        start_idx = self.input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= self.max_model_len, (
            "Sampled token IDs exceed the max model length. "
            f"Total number of tokens: {end_idx} > max_model_len: "
            f"{self.max_model_len}")

        self.input_batch.token_ids_cpu[req_idx,
                                       start_idx:end_idx] = sampled_ids
        # Remember where the (possibly placeholder) ids were written so
        # that the next step can overwrite them.
        self.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
        self.input_batch.num_tokens_no_spec[req_idx] = end_idx
        self.input_batch.num_tokens[req_idx] = end_idx
        req_id = self.input_batch.req_ids[req_idx]
        if req_id in self.requests:
            req_state = self.requests[req_id]
            self.last_sampled_req_ids.append(req_id)
            self.last_sampled_token_lens.append(
                len(req_state.output_token_ids))
            req_state.output_token_ids.extend(sampled_ids)

    if not self.speculative_config:
        # Speculative decoding is not enabled.
        spec_token_ids = None
        fix_draft_req_ids = None
    else:
        if self.last_draft_host_tokens is not None:
            # Wait for the previous step's async copy of the draft tokens.
            self.last_draft_event.synchronize()
            fix_draft_token_ids = self.last_draft_host_tokens.tolist()
        spec_token_ids = self.propose_draft_token_ids(
            scheduler_output,
            valid_sampled_token_ids,
            sampling_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            spec_decode_metadata,
            attn_metadata,
        )

    # Clear KVConnector state after all KVs are generated.
    if has_kv_transfer_group():
        get_kv_transfer_group().clear_connector_metadata()

    self.eplb_step()

    model_output = ZeroV1ModelRunnerOutput(
        req_ids=self.input_batch.req_ids,
        req_id_to_index=self.input_batch.req_id_to_index,
        sampled_token_ids=valid_sampled_token_ids,
        spec_token_ids=spec_token_ids,
        logprobs=logprobs_lists,
        prompt_logprobs_dict=prompt_logprobs_dict,
        pooler_output=[],
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        num_nans_in_logits=num_nans_in_logits,
        fix_req_ids=fix_req_ids,
        fix_sampled_token_ids=fix_sampled_token_ids,
        fix_draft_tokens_ids=fix_draft_token_ids,
        fix_draft_req_ids=fix_draft_req_ids,
        is_output_valid=is_output_valid)
    return model_output
\ No newline at end of file
vllm/zero_overhead/v1/outputs.py
View file @
22a95571
...
...
@@ -6,4 +6,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
    """ModelRunnerOutput extended with deferred ("fix") fields.

    The extra fields default to "nothing to fix" so that the class can be
    constructed with only the base-class fields.
    """

    # [num_reqs]
    # Request ids that the rows of `fix_sampled_token_ids` refer to.
    fix_req_ids: list[str] = None
    # Sampled token ids reported after the fact, one list per request.
    fix_sampled_token_ids: list[list[int]] = None
    # Request ids that the rows of `fix_draft_tokens_ids` refer to.
    fix_draft_req_ids: list[str] = None
    # Draft token ids reported after the fact, one list per request.
    fix_draft_tokens_ids: list[list[int]] = None
    # Set by the producer to say whether `sampled_token_ids` already holds
    # real token ids (True) or placeholder values (False).
    is_output_valid: bool = True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment