Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22a95571
Commit
22a95571
authored
Aug 18, 2025
by
zhuwenwen
Browse files
add v1 engine + deepseek r1 mtp + zero-overhead scheduler
parent
ac4cc84e
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
727 additions
and
286 deletions
+727
-286
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+15
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-28
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+7
-2
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+148
-111
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+533
-140
vllm/zero_overhead/v1/outputs.py
vllm/zero_overhead/v1/outputs.py
+4
-1
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
22a95571
...
...
@@ -764,10 +764,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats
=
torch
.
from_numpy
(
query_lens
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
).
contiguous
()
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
if
envs
.
VLLM_ZERO_OVERHEAD
:
decode_block_table_tensor
=
torch
.
empty
((
self
.
_num_decode_tokens
,
block_table_tensor
.
shape
[
1
]),
device
=
block_table_tensor
.
device
)
arange_np
=
np
.
arange
(
self
.
_num_decodes
)
indices_np
=
np
.
repeat
(
arange_np
,
query_lens
)
indices
=
torch
.
from_numpy
(
indices_np
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
)
decode_block_table_tensor
=
block_table_tensor
[
indices
].
contiguous
()
decode_seq_lens
=
seq_lens
[
indices
].
contiguous
()
else
:
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
self
.
_num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
seq_lens_minus
=
torch
.
from_numpy
(
rarange
).
to
(
torch
.
int32
).
pin_memory
().
to
(
seq_lens
.
device
,
non_blocking
=
True
).
contiguous
()
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
22a95571
...
...
@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.zero_overhead.v1.gpu_model_runner
import
execute_model_sampled
,
zero_prepare_inputs
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
bind_kv_cache
,
gather_mm_placeholders
,
...
...
@@ -1020,15 +1019,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
if
envs
.
VLLM_ZERO_OVERHEAD
:
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
else
:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...
...
@@ -1440,9 +1450,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ZERO_OVERHEAD
:
zero_prepare_inputs
(
self
,
scheduler_output
,
input_ids
)
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
model_output
,
finished_sending
,
finished_recving
=
\
...
...
@@ -1591,21 +1598,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
envs
.
VLLM_ZERO_OVERHEAD
:
return
execute_model_sampled
(
self
,
max_gen_len
,
sampled_token_ids
,
discard_sampled_tokens_req_indices
,
scheduler_output
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
logprobs_lists
,
prompt_logprobs_dict
,
finished_sending
,
finished_recving
,
num_nans_in_logits
)
if
max_gen_len
==
1
:
# No spec decode tokens.
...
...
vllm/v1/worker/gpu_worker.py
View file @
22a95571
...
...
@@ -33,6 +33,7 @@ from vllm.v1.utils import report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
...
...
@@ -187,8 +188,12 @@ class Worker(WorkerBase):
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
...
...
vllm/zero_overhead/v1/core.py
View file @
22a95571
...
...
@@ -12,11 +12,15 @@ requsets_valid_token_len = {}
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_valid_token_len
:
bool
=
False
)
->
bool
:
if
use_valid_token_len
:
if
request
.
request_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
else
:
valid_output_len
=
request
.
num_output_tokens
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
if
(
valid_num_tokens
>=
max_model_len
or
valid_output_len
>=
request
.
max_tokens
):
...
...
@@ -60,110 +64,118 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# fix last model out in zero overhead
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
if
req_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
req_id
]
=
0
valid_output_len
=
requsets_valid_token_len
[
req_id
]
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
if
isinstance
(
generated_token_ids
,
int
):
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
if
model_runner_output
.
fix_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
if
req_id
not
in
requsets_valid_token_len
:
requsets_valid_token_len
[
req_id
]
=
0
valid_output_len
=
requsets_valid_token_len
[
req_id
]
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
if
isinstance
(
generated_token_ids
,
int
):
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
spec_token_ids
[
req_index
])
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
req_id
,
new_token_ids
)
# spec_token_ids comes from the model runner output
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
assert
not
prompt_logprobs_tensors
# fix last model out in zero overhead
if
model_runner_output
.
fix_draft_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_draft_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
# Add newly generated spec token ids to the request.
if
model_runner_output
.
fix_draft_tokens_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
])
else
:
request
.
spec_token_ids
=
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for
request
in
scheduler
.
running
:
if
request
.
is_finished
():
continue
req_id
=
request
.
request_id
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
...
...
@@ -197,7 +209,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if
request
.
has_encoder_inputs
:
scheduler
.
_free_encoder_inputs
(
request
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
...
...
@@ -210,19 +221,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
# del new_token_ids[num_new:] # Trim new tokens if needed.
# break
if
model_runner_output
.
is_output_valid
:
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
if
model_runner_output
.
is_output_valid
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
...
...
@@ -252,6 +268,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_token_ids
[
req_index
])
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
if
model_runner_output
.
is_output_valid
:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
if
not
stopped
:
new_running
.
append
(
request
)
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
22a95571
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/v1/outputs.py
View file @
22a95571
...
...
@@ -6,4 +6,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class
ZeroV1ModelRunnerOutput
(
ModelRunnerOutput
):
# [num_reqs]
fix_req_ids
:
list
[
str
]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
\ No newline at end of file
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
fix_draft_req_ids
:
list
[
list
[
int
]]
=
None
fix_draft_tokens_ids
:
list
[
list
[
int
]]
=
None
is_output_valid
:
bool
=
True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment