Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a183111e
Commit
a183111e
authored
Nov 24, 2025
by
lizhigong
Browse files
新增pp2零消耗调度分支
parent
83c1f04a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1183 additions
and
0 deletions
+1183
-0
vllm/zero_overhead/v1/PP2mtp/gpu_model_runner.py
vllm/zero_overhead/v1/PP2mtp/gpu_model_runner.py
+749
-0
vllm/zero_overhead/v1/PP2mtp/gpu_worker.py
vllm/zero_overhead/v1/PP2mtp/gpu_worker.py
+420
-0
vllm/zero_overhead/v1/PP2mtp/outputs.py
vllm/zero_overhead/v1/PP2mtp/outputs.py
+14
-0
No files found.
vllm/zero_overhead/v1/PP2mtp/gpu_model_runner.py
0 → 100644
View file @
a183111e
from
typing
import
Any
,
Optional
,
Union
import
torch
import
numpy
as
np
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
async_tensor_h2d
,
round_up
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.zero_overhead.v1.eagle
import
V1ZeroEagleProposer
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.profiler.prof
import
profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
class
V1ZeroModelRunner
(
GPUModelRunner
):
def
__init__
(
self
,
vllm_config
,
device
):
super
().
__init__
(
vllm_config
,
device
)
self
.
last_sampled_token_ids
=
None
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_host_tokens
=
None
self
.
token_ids_cpu_fix_record
=
[]
self
.
last_draft_token_ids
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_scheduler_max_num_tokens
=
0
if
hasattr
(
self
,
'drafter'
)
and
isinstance
(
self
.
drafter
,
EagleProposer
):
self
.
drafter
=
V1ZeroEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
dict
[
str
,
Any
],
bool
,
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
],
np
.
ndarray
]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
assert
num_reqs
>
0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
# Get the number of scheduled tokens for each request.
req_ids
=
self
.
input_batch
.
req_ids
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
self
.
spec_scheduler_max_num_tokens
=
max_num_scheduled_tokens
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
num_scheduled_tokens
)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens
,
arange
=
self
.
_get_cumsum_and_arange
(
num_scheduled_tokens
)
# Get positions.
positions_np
=
self
.
positions_np
[:
total_num_scheduled_tokens
]
np
.
add
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_indices
],
arange
,
out
=
positions_np
)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
self
.
_calc_mrope_positions
(
scheduler_output
)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices
=
(
positions_np
+
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch
.
index_select
(
self
.
input_batch
.
token_ids_cpu_tensor
.
flatten
(),
0
,
torch
.
from_numpy
(
token_indices
),
out
=
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
])
# Calculate the slot mapping for each KV cache group.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
block_size
=
kv_cache_group_spec
.
kv_cache_spec
.
block_size
block_table
:
BlockTable
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices
=
(
req_indices
*
block_table
.
max_num_blocks_per_req
+
positions_np
//
block_size
)
block_table_cpu
=
block_table
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
(
)[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
block_size
np
.
add
(
block_numbers
*
block_size
,
block_offsets
,
out
=
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
# Copy the tensors to the GPU.
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
self
.
zero_prepare_inputs
(
scheduler_output
,
self
.
input_ids
)
if
self
.
uses_mrope
:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self
.
mrope_positions
[:,
:
total_num_scheduled_tokens
].
copy_
(
self
.
mrope_positions_cpu
[:,
:
total_num_scheduled_tokens
],
non_blocking
=
True
)
else
:
# Common case (1D positions)
self
.
positions
[:
total_num_scheduled_tokens
].
copy_
(
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
self
.
query_start_loc
[:
num_reqs
+
1
].
copy_
(
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache
self
.
seq_lens
[
num_reqs
:].
fill_
(
0
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
self
.
query_start_loc
[
num_reqs
+
1
:].
fill_
(
self
.
query_start_loc_cpu
[
num_reqs
].
item
())
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
builder
=
self
.
attn_metadata_builders
[
kv_cache_group_id
]
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
builder
,
)
attn_metadata_i
=
(
builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
,
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attention_cuda_graphs
=
all
(
b
.
can_run_in_cudagraph
(
common_attn_metadata
)
for
b
in
self
.
attn_metadata_builders
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
spec_decode_metadata
=
None
else
:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
logits_indices
=
spec_decode_metadata
.
logits_indices
# Hot-Swap lora model
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
return
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
spec_decode_metadata
,
num_scheduled_tokens
)
def
zero_prepare_inputs
(
self
,
scheduler_output
,
input_ids
):
req_ids
=
self
.
input_batch
.
req_ids
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
self
.
last_draft_token_ids
is
not
None
:
draft_tokens_num
=
self
.
last_draft_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
draft_tokens_num
for
num_idx
in
range
(
draft_tokens_num
):
update_req_indices
.
append
(
req_idx
+
num_idx
)
input_ids_indices
.
append
(
token_idx
+
num_idx
+
1
)
token_idx
+=
draft_tokens_num
+
1
if
len
(
update_req_indices
)
>
0
:
update_req_indices_tensor
=
async_tensor_h2d
(
update_req_indices
,
torch
.
int32
,
self
.
device
,
True
)
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
update_req_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
if
self
.
last_sampled_token_ids
is
not
None
:
sampled_tokens_num
=
self
.
last_sampled_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
sampled_tokens_num
update_req_indices
.
append
(
req_idx
)
input_ids_indices
.
append
(
token_idx
)
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
if
len
(
update_req_indices
)
>
0
:
update_req_indices_tensor
=
async_tensor_h2d
(
update_req_indices
,
torch
.
int32
,
self
.
device
,
True
)
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
True
)
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
for
i
in
range
(
sampled_tokens_num
):
input_ids
[
input_ids_indices_tensor
+
i
]
=
last_sampled_token_ids
[
update_req_indices_tensor
+
i
]
def
propose_draft_token_ids
(
self
,
scheduler_output
:
"SchedulerOutput"
,
num_accepted_tokens_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
attn_metadata
:
dict
[
str
,
Any
],
)
->
list
[
list
[
int
]]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
spec_token_ids
=
self
.
propose_ngram_draft_token_ids
(
sampled_token_ids
)
elif
self
.
speculative_config
.
method
==
"medusa"
:
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
if
sample_hidden_states
.
shape
[
0
]
==
len
(
sampled_token_ids
):
# The input to the target model does not include draft tokens.
hidden_states
=
sample_hidden_states
else
:
indices
=
[]
offset
=
0
for
num_draft
,
tokens
in
zip
(
spec_decode_metadata
.
num_draft_tokens
,
sampled_token_ids
):
indices
.
append
(
offset
+
len
(
tokens
)
-
1
)
offset
+=
num_draft
+
1
indices
=
torch
.
tensor
(
indices
,
device
=
self
.
device
)
hidden_states
=
sample_hidden_states
[
indices
]
spec_token_ids
=
self
.
drafter
.
propose
(
target_hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
)
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
row_indices
=
torch
.
arange
(
sampled_token_ids
.
size
(
0
),
device
=
sampled_token_ids
.
device
)
next_token_ids
=
sampled_token_ids
[
row_indices
,
num_accepted_tokens_tensor
].
flatten
()
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_names
[
0
]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if
hasattr
(
eagle_attn_metadata
,
"block_table"
):
block_table
=
eagle_attn_metadata
.
block_table
else
:
block_table
=
None
spec_scheduler_max_num_tokens
=
self
.
spec_scheduler_max_num_tokens
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[:
num_scheduled_tokens
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
# TODO(woosuk): Refactor this.
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
num_accepted_tokens_tensor
,
)
spec_scheduler_max_num_tokens
=
1
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
if
self
.
use_aux_hidden_state_outputs
:
target_hidden_states
=
torch
.
cat
(
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
self
.
drafter
.
spec_scheduler_max_num_tokens
=
spec_scheduler_max_num_tokens
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
decoding
=
spec_decode_metadata
is
not
None
,
)
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
self
.
last_draft_token_ids
=
draft_token_ids
self
.
last_draft_host_tokens
=
draft_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_draft_event
.
record
()
return
spec_token_ids
@
torch
.
inference_mode
()
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
# Return empty ModelRunnerOutput if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
self
.
kv_connector_no_forward
(
scheduler_output
)
# Prepare the decoder inputs.
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
spec_decode_metadata
,
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_scheduled_tokens
)
else
:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
else
:
num_input_tokens
=
num_scheduled_tokens
# Padding for DP
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_mm_encoder
(
scheduler_output
)
mm_embeds
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
else
:
mm_embeds
=
[]
if
self
.
is_multimodal_model
and
get_pp_group
().
is_first_rank
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
mm_embeds
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
mm_embeds
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
self
.
inputs_embeds
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds
)
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
input_ids
=
None
else
:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_input_tokens
]
else
:
positions
=
self
.
positions
[:
num_input_tokens
]
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
else
:
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
):
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
inputs_embeds
,
scheduler_output
,
intermediate_tensors
,
skip_cuda_graphs
)
else
:
# Run the model.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
skip_cuda_graphs
,
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output
=
\
self
.
parallel_config
.
distributed_executor_backend
\
==
"external_launcher"
and
len
(
get_pp_group
().
ranks
)
>
0
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
if
not
broadcast_pp_output
:
return
hidden_states
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
all_gather_group
=
get_tp_group
())
logits
=
None
else
:
if
self
.
input_batch
.
pooling_params
:
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
finished_sending
,
finished_recving
)
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
broadcast_pp_output
:
model_output_broadcast_data
=
{
"logits"
:
logits
.
contiguous
(),
}
if
logits
is
not
None
else
{}
model_output_broadcast_data
=
get_pp_group
().
broadcast_tensor_dict
(
model_output_broadcast_data
,
src
=
len
(
get_pp_group
().
ranks
)
-
1
)
assert
model_output_broadcast_data
is
not
None
logits
=
model_output_broadcast_data
[
"logits"
]
# Apply structured output bitmasks if present
if
scheduler_output
.
grammar_bitmask
is
not
None
:
self
.
apply_grammar_bitmask
(
scheduler_output
,
logits
)
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
spec_decode_metadata
is
None
:
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
else
:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert
logits
is
not
None
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
sampler_output
=
self
.
sampler
(
logits
=
bonus_logits
,
sampling_metadata
=
sampling_metadata
,
)
bonus_token_ids
=
sampler_output
.
sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
num_nans_in_logits
=
{}
if
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
:
num_nans_in_logits
=
self
.
_get_nans_in_logits
(
logits
)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices
=
[]
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_lists
=
logprobs_tensors
.
tolists
()
\
if
logprobs_tensors
is
not
None
else
None
# Compute prompt logprobs if needed.
prompt_logprobs_dict
=
self
.
_get_prompt_logprobs_dict
(
hidden_states
[:
num_scheduled_tokens
],
scheduler_output
,
)
fix_req_ids
=
None
fix_sampled_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_req_ids
=
self
.
last_sampled_req_ids
is_output_valid
=
False
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
spec_sampler_event
.
record
()
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
if
self
.
speculative_config
:
self
.
spec_sampler_event
.
synchronize
()
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids_cpu
.
tolist
()
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids_cpu
,
self
.
input_batch
.
vocab_size
,
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampled_token_ids
=
None
is_output_valid
=
True
else
:
# No spec decode tokens.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
continue
req_id
=
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
self
.
token_ids_cpu_fix_record
.
clear
()
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
-
1
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
,
(
"Sampled token IDs exceed the max model length. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len: "
f
"
{
self
.
max_model_len
}
"
)
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
cache_output_len
=
len
(
req_state
.
output_token_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
self
.
eplb_step
()
model_output
=
ZeroV1ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
fix_req_ids
=
fix_req_ids
,
fix_sampled_token_ids
=
fix_sampled_token_ids
,
fix_draft_tokens_ids
=
fix_draft_token_ids
,
fix_draft_req_ids
=
fix_draft_req_ids
,
is_output_valid
=
is_output_valid
)
return
model_output
\ No newline at end of file
vllm/zero_overhead/v1/PP2mtp/gpu_worker.py
0 → 100644
View file @
a183111e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import
gc
import
os
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch.distributed
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.device_allocator.cumem
import
CuMemAllocator
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
,
set_custom_all_reduce
)
from
vllm.distributed.kv_transfer
import
ensure_kv_transfer_initialized
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.v1.core.sched.output
import
SchedulerOutput
class
Worker
(
WorkerBase
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
=
False
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
)
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
# Buffers saved before sleep
self
.
_sleep_saved_buffers
:
dict
[
str
,
torch
.
Tensor
]
=
{}
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
torch_profiler_trace_dir
=
envs
.
VLLM_TORCH_PROFILER_DIR
logger
.
info
(
"Profiling enabled. Traces will be saved to: %s"
,
torch_profiler_trace_dir
)
self
.
profiler
=
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
with_stack
=
True
,
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
torch_profiler_trace_dir
,
use_gzip
=
True
))
else
:
self
.
profiler
=
None
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
free_bytes_before_sleep
=
torch
.
cuda
.
mem_get_info
()[
0
]
# Save the buffers before level 2 sleep
if
level
==
2
:
model
=
self
.
model_runner
.
model
self
.
_sleep_saved_buffers
=
{
name
:
buffer
.
cpu
().
clone
()
for
name
,
buffer
in
model
.
named_buffers
()
}
allocator
=
CuMemAllocator
.
get_instance
()
allocator
.
sleep
(
offload_tags
=
(
"weights"
,
)
if
level
==
1
else
tuple
())
free_bytes_after_sleep
,
total
=
torch
.
cuda
.
mem_get_info
()
freed_bytes
=
free_bytes_after_sleep
-
free_bytes_before_sleep
used_bytes
=
total
-
free_bytes_after_sleep
assert
freed_bytes
>=
0
,
"Memory usage increased after sleeping."
logger
.
info
(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use."
,
freed_bytes
/
GiB_bytes
,
used_bytes
/
GiB_bytes
)
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
allocator
=
CuMemAllocator
.
get_instance
()
allocator
.
wake_up
(
tags
)
# Restore the buffers after level 2 sleep
if
len
(
self
.
_sleep_saved_buffers
):
model
=
self
.
model_runner
.
model
for
name
,
buffer
in
model
.
named_buffers
():
if
name
in
self
.
_sleep_saved_buffers
:
buffer
.
data
.
copy_
(
self
.
_sleep_saved_buffers
[
name
].
data
)
self
.
_sleep_saved_buffers
=
{}
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
def
init_device
(
self
):
if
self
.
device_config
.
device
.
type
==
"cuda"
:
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os
.
environ
[
"TORCH_NCCL_AVOID_RECORD_STREAMS"
]
=
"1"
# This env var set by Ray causes exceptions with graph building.
os
.
environ
.
pop
(
"NCCL_ASYNC_ERROR_HANDLING"
,
None
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
torch
.
cuda
.
set_device
(
self
.
device
)
_check_if_gpu_supports_dtype
(
self
.
model_config
.
dtype
)
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
# take current memory snapshot
self
.
init_snapshot
=
MemorySnapshot
()
self
.
requested_memory
=
(
self
.
init_snapshot
.
total_memory
*
self
.
cache_config
.
gpu_memory_utilization
)
if
self
.
init_snapshot
.
free_memory
<
self
.
requested_memory
:
GiB
=
lambda
b
:
round
(
b
/
GiB_bytes
,
2
)
raise
ValueError
(
f
"Free memory on device "
f
"(
{
GiB
(
self
.
init_snapshot
.
free_memory
)
}
/"
f
"
{
GiB
(
self
.
init_snapshot
.
total_memory
)
}
GiB) on startup "
f
"is less than desired GPU memory utilization "
f
"(
{
self
.
cache_config
.
gpu_memory_utilization
}
, "
f
"
{
GiB
(
self
.
requested_memory
)
}
GiB). Decrease GPU memory "
f
"utilization or reduce GPU memory used by other processes."
)
else
:
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Initialize the distributed environment.
init_worker_distributed_environment
(
self
.
vllm_config
,
self
.
rank
,
self
.
distributed_init_method
,
self
.
local_rank
)
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
if
envs
.
VLLM_ZERO_OVERHEAD
:
logger
.
info
(
'use zero overhead model_runner'
)
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
report_usage_stats
(
self
.
vllm_config
)
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def
load_model
(
self
)
->
None
:
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
allocator
=
CuMemAllocator
.
get_instance
()
assert
allocator
.
get_current_usage
()
==
0
,
(
"Sleep mode can only be "
"used for one instance per process."
)
context
=
allocator
.
use_memory_pool
(
tag
=
"weights"
)
else
:
from
contextlib
import
nullcontext
context
=
nullcontext
()
with
context
:
self
.
model_runner
.
load_model
()
@
torch
.
inference_mode
()
def
determine_available_memory
(
self
)
->
int
:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the free memory that can be used for KV cache in
bytes.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
()
GiB
=
lambda
b
:
b
/
GiB_bytes
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with
memory_profiling
(
self
.
init_snapshot
,
weights_memory
=
int
(
self
.
model_runner
.
model_memory_usage
))
as
profile_result
:
self
.
model_runner
.
profile_run
()
free_gpu_memory
=
profile_result
.
after_profile
.
free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert
self
.
init_snapshot
.
free_memory
>
free_gpu_memory
,
(
"Error in memory profiling. "
f
"Initial free memory
{
GiB
(
self
.
init_snapshot
.
free_memory
)
}
GiB, "
f
"current free memory
{
GiB
(
free_gpu_memory
)
}
GiB. "
"This happens when other processes sharing the same container "
"release GPU memory while vLLM is profiling during initialization. "
"To fix this, ensure consistent GPU memory allocation or "
"isolate vLLM in its own container."
)
available_kv_cache_memory
=
self
.
requested_memory
\
-
profile_result
.
non_kv_cache_memory
logger
.
debug
(
"Initial free memory: %.2f GiB, free memory: %.2f GiB, "
"requested GPU memory: %.2f GiB"
,
GiB
(
self
.
init_snapshot
.
free_memory
),
GiB
(
free_gpu_memory
),
GiB
(
self
.
requested_memory
))
logger
.
debug
(
profile_result
)
logger
.
info
(
"Available KV cache memory: %.2f GiB"
,
GiB
(
available_kv_cache_memory
))
gc
.
collect
()
return
int
(
available_kv_cache_memory
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Allocate GPU KV cache with the specified kv_cache_config."""
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
allocator
=
CuMemAllocator
.
get_instance
()
context
=
allocator
.
use_memory_pool
(
tag
=
"kv_cache"
)
else
:
from
contextlib
import
nullcontext
context
=
nullcontext
()
with
context
:
self
.
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
def
compile_or_warm_up_model
(
self
)
->
None
:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
.
copy
()
if
not
self
.
model_config
.
enforce_eager
:
warmup_sizes
=
[
x
for
x
in
warmup_sizes
if
x
not
in
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
]
# We skip EPLB here since we don't want to record dummy metrics
for
size
in
sorted
(
warmup_sizes
,
reverse
=
True
):
logger
.
info
(
"Compile and warming up model for size %d"
,
size
)
self
.
model_runner
.
_dummy_run
(
size
,
skip_eplb
=
True
)
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
()
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if
get_pp_group
().
is_last_rank
:
max_num_reqs
=
min
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
)
# We skip EPLB here since we don't want to record dummy metrics
hidden_states
,
last_hidden_states
=
\
self
.
model_runner
.
_dummy_run
(
num_tokens
=
max_num_reqs
,
skip_eplb
=
True
,
)
if
self
.
model_runner
.
is_pooling_model
:
self
.
model_runner
.
_dummy_pooler_run
(
hidden_states
)
else
:
self
.
model_runner
.
_dummy_sampler_run
(
hidden_states
=
last_hidden_states
)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
@
torch
.
inference_mode
()
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
Optional
[
ModelRunnerOutput
]:
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
IntermediateTensors
(
get_pp_group
().
recv_tensor_dict
(
all_gather_group
=
get_tp_group
()))
if
envs
.
VLLM_ZERO_OVERHEAD
:
use_stream
=
zero_overhead_stream
(
self
.
device
)
with
torch
.
cuda
.
stream
(
use_stream
):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
else
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
parallel_config
=
self
.
vllm_config
.
parallel_config
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
and
not
get_pp_group
().
is_last_rank
:
assert
isinstance
(
output
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
all_gather_group
=
get_tp_group
())
return
None
assert
isinstance
(
output
,
ModelRunnerOutput
)
return
output
if
self
.
is_driver_worker
else
None
def
profile
(
self
,
is_start
:
bool
=
True
):
if
self
.
profiler
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
if
is_start
:
self
.
profiler
.
start
()
else
:
self
.
profiler
.
stop
()
print
(
self
.
profiler
.
key_averages
().
table
(
sort_by
=
"self_cuda_time_total"
))
def
execute_dummy_batch
(
self
)
->
None
:
self
.
model_runner
.
_dummy_run
(
1
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
set
[
int
]:
return
self
.
model_runner
.
list_loras
()
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
pin_lora
(
lora_id
)
def
check_health
(
self
)
->
None
:
# worker will always be healthy as long as it's running.
return
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
vllm.model_executor.model_loader
import
ShardedStateLoader
ShardedStateLoader
.
save_model
(
self
.
model_runner
.
model
,
path
,
pattern
=
pattern
,
max_size
=
max_size
,
)
def
save_tensorized_model
(
self
,
tensorizer_config
:
"TensorizerConfig"
,
)
->
None
:
self
.
model_runner
.
save_tensorized_model
(
tensorizer_config
=
tensorizer_config
,
)
def
init_worker_distributed_environment
(
vllm_config
:
VllmConfig
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
)
->
None
:
"""Initialize the distributed environment."""
parallel_config
=
vllm_config
.
parallel_config
set_custom_all_reduce
(
not
parallel_config
.
disable_custom_all_reduce
)
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
distributed_init_method
,
local_rank
,
backend
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
ensure_kv_transfer_initialized
(
vllm_config
)
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
current_platform
.
has_device_capability
(
80
):
capability
=
current_platform
.
get_device_capability
()
gpu_name
=
current_platform
.
get_device_name
()
if
capability
is
None
:
compute_str
=
"does not have a compute capability"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"has compute capability
{
version_str
}
"
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU
{
compute_str
}
. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half."
)
vllm/zero_overhead/v1/PP2mtp/outputs.py
0 → 100644
View file @
a183111e
from
dataclasses
import
dataclass
from
vllm.v1.outputs
import
ModelRunnerOutput
@
dataclass
class
ZeroV1ModelRunnerOutput
(
ModelRunnerOutput
):
# [num_reqs]
fix_req_ids
:
list
[
str
]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
fix_draft_req_ids
:
list
[
str
]
=
None
fix_draft_tokens_ids
:
list
[
list
[
int
]]
=
None
is_output_valid
:
bool
=
True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment