Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d867671
Commit
7d867671
authored
Mar 18, 2025
by
lizhigong
Browse files
fix llm_engine to zero_overhead
parent
08c2298a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
142 additions
and
56 deletions
+142
-56
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+52
-25
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-0
vllm/model_executor/layers/ops/update_input.py
vllm/model_executor/layers/ops/update_input.py
+6
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+38
-21
vllm/sequence.py
vllm/sequence.py
+11
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+29
-4
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-2
No files found.
vllm/engine/llm_engine.py
View file @
7d867671
...
@@ -1233,6 +1233,27 @@ class LLMEngine:
...
@@ -1233,6 +1233,27 @@ class LLMEngine:
return
None
return
None
def
_fix_last_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
sample_out_list
=
output
[
0
].
sampler_out_tenosr
.
cpu
().
tolist
()
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
,
token_id
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
,
sample_out_list
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
continue
if
seq_group_metadata
.
do_sample
:
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
sample
.
output_token
=
token_id
[
0
]
seq
.
fix_last_token_id
(
sample
.
output_token
)
def
_advance_to_next_step
(
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -1386,9 +1407,14 @@ class LLMEngine:
...
@@ -1386,9 +1407,14 @@ class LLMEngine:
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
profile
.
ProfRangeAutoPush
(
'execute_model'
)
profile
.
ProfRangeAutoPush
(
'execute_model'
)
last_outputs
=
None
last_outputs_ids
=
None
last_outputs_tensor
=
None
if
self
.
zero_overhead
:
if
self
.
zero_overhead
:
last_outputs
=
self
.
trans_last_output_tensor
(
self
.
output_recorder
[
self
.
step_switch
])
recode_output
=
self
.
output_recorder
[
self
.
step_switch
]
if
recode_output
is
not
None
:
last_output
=
recode_output
[
0
][
0
]
last_outputs_ids
,
last_outputs_tensor
=
last_output
.
sampler_out_ids
,
last_output
.
sampler_out_tenosr
self
.
output_recorder
[
self
.
step_switch
]
=
None
# only use for once
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
# Check if we have a cached last_output from the previous iteration.
# Check if we have a cached last_output from the previous iteration.
...
@@ -1398,14 +1424,6 @@ class LLMEngine:
...
@@ -1398,14 +1424,6 @@ class LLMEngine:
last_sampled_token_ids
=
\
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
# print('seq_group_metadata_list', len(seq_group_metadata_list))
# print('scheduler_outputs.blocks_to_swap_in', len(scheduler_outputs.blocks_to_swap_in))
# print('scheduler_outputs.num_lookahead_slots', scheduler_outputs.num_lookahead_slots)
# print('scheduler_outputs.running_queue_size', scheduler_outputs.running_queue_size)
# print('finished_requests_ids', len(finished_requests_ids))
# print('last_sampled_token_ids', last_sampled_token_ids)
# print('self.model_executor', type(self.model_executor))
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
@@ -1417,7 +1435,8 @@ class LLMEngine:
...
@@ -1417,7 +1435,8 @@ class LLMEngine:
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
,
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs
=
last_outputs
)
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
virtual_engine
]
...
@@ -1437,19 +1456,28 @@ class LLMEngine:
...
@@ -1437,19 +1456,28 @@ class LLMEngine:
# No outputs in this case
# No outputs in this case
outputs
=
[]
outputs
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
self
.
zero_overhead
:
if
self
.
zero_overhead
:
self
.
output_recorder
[
self
.
step_switch
]
=
outputs
self
.
output_recorder
[
self
.
step_switch
]
=
[
outputs
,
seq_group_metadata_list
,
scheduler_outputs
]
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
self
.
step_switch
=
1
-
self
.
step_switch
self
.
step_switch
=
1
-
self
.
step_switch
outputs
=
self
.
output_recorder
[
self
.
step_switch
]
if
outputs
is
None
:
recode_output
=
self
.
output_recorder
[
self
.
step_switch
]
if
recode_output
is
None
:
return
None
return
None
#同步上一次的output
outputs
,
seq_group_metadata_list
,
scheduler_outputs
=
self
.
output_recorder
[
self
.
step_switch
]
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
_fix_last_step
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps.
# clear the cache if we have finished all the steps.
...
@@ -1473,10 +1501,10 @@ class LLMEngine:
...
@@ -1473,10 +1501,10 @@ class LLMEngine:
if
outputs
and
allow_async_output_proc
:
if
outputs
and
allow_async_output_proc
:
assert
len
(
outputs
)
==
1
,
(
assert
len
(
outputs
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
"Async postprocessor expects only a single output set"
)
if
not
self
.
zero_overhead
:
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
...
@@ -1505,7 +1533,6 @@ class LLMEngine:
...
@@ -1505,7 +1533,6 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
return
ctx
.
request_outputs
def
_has_remaining_steps
(
def
_has_remaining_steps
(
...
...
vllm/entrypoints/llm.py
View file @
7d867671
...
@@ -1388,6 +1388,7 @@ class LLM:
...
@@ -1388,6 +1388,7 @@ class LLM:
total_out_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
step_outputs
=
self
.
llm_engine
.
step
()
print
(
'###step_outputs'
,
step_outputs
)
if
step_outputs
is
None
:
if
step_outputs
is
None
:
continue
continue
for
output
in
step_outputs
:
for
output
in
step_outputs
:
...
...
vllm/model_executor/layers/ops/update_input.py
View file @
7d867671
...
@@ -21,4 +21,9 @@ def _update_input_tokens(
...
@@ -21,4 +21,9 @@ def _update_input_tokens(
_seq_ids
=
tl
.
load
(
seq_ids
+
i
)
_seq_ids
=
tl
.
load
(
seq_ids
+
i
)
if
_seq_ids
==
_input_seq_id
:
if
_seq_ids
==
_input_seq_id
:
output_token
=
tl
.
load
(
sample_output
+
i
)
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
\ No newline at end of file
def
UpdateInputTokens
(
input_tokens
,
input_seq_ids
,
last_sample
,
last_ids
):
last_ids
=
last_ids
.
to
(
'cuda'
)
grid
=
[
input_seq_ids
.
shape
[
0
],
1
,
1
]
_update_input_tokens
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
\ No newline at end of file
vllm/model_executor/layers/sampler.py
View file @
7d867671
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
itertools
import
os
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
...
@@ -72,11 +73,12 @@ class SampleResultArgsType:
...
@@ -72,11 +73,12 @@ class SampleResultArgsType:
# Implemented by guanyu
# Implemented by guanyu
@
dataclass
@
dataclass
class
SampleDeviceToDevices
:
class
SampleDeviceToDevices
:
num_parent_seq
:
torch
.
Tensor
=
None
def
__init__
(
self
):
seq_id
:
torch
.
Tensor
=
None
self
.
seq_id
:
torch
.
Tensor
=
None
random_samples
:
torch
.
Tensor
=
None
self
.
random_samples
:
torch
.
Tensor
=
None
sample_idx
:
int
=
None
self
.
zero_overhead
:
bool
=
False
d2d_data
=
SampleDeviceToDevices
()
d2d_data
=
SampleDeviceToDevices
()
# Union of non-deferred (single-step scheduling)
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# vs deferred (multi-step scheduling)
...
@@ -144,6 +146,9 @@ class SamplerOutput(
...
@@ -144,6 +146,9 @@ class SamplerOutput(
# tree-style cartesian candidates
# tree-style cartesian candidates
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_tenosr
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__getitem__
(
self
,
idx
:
int
)
->
CompletionSequenceGroupOutput
:
def
__getitem__
(
self
,
idx
:
int
)
->
CompletionSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
...
@@ -174,7 +179,10 @@ class SamplerOutput(
...
@@ -174,7 +179,10 @@ class SamplerOutput(
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
, "
f
"logits=
{
self
.
logits
}
, "
f
"logits=
{
self
.
logits
}
, "
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
)"
)
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
, "
f
"sampler_out_tenosr=
{
self
.
sampler_out_tenosr
}
, "
f
"sampler_out_ids=
{
self
.
sampler_out_ids
}
, "
f
")"
)
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
@@ -206,6 +214,8 @@ class Sampler(nn.Module):
...
@@ -206,6 +214,8 @@ class Sampler(nn.Module):
# speculative decoding.
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
self
.
should_modify_greedy_probs_inplace
=
False
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
d2d_data
.
zero_overhead
=
self
.
zero_overhead
def
_init_sampling_tensors
(
def
_init_sampling_tensors
(
self
,
self
,
...
@@ -503,8 +513,8 @@ def _random_sample(
...
@@ -503,8 +513,8 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
seq_group has do_sample=False, tuple contains ([], [])
"""
"""
# Find the maximum n value of the prompt phase requests.
# Find the maximum n value of the prompt phase requests.
#random_samples = random_samples.cpu()删除,取消gpu->cpu之间的同步
if
not
d2d_data
.
zero_overhead
:
random_samples
=
random_samples
.
cpu
()
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
sample_idx
=
0
results
:
SampleResultType
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
...
@@ -516,20 +526,24 @@ def _random_sample(
...
@@ -516,20 +526,24 @@ def _random_sample(
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
d2d_data
.
num_parent_seq
=
num_parent_seqs
if
is_prompt
:
if
is_prompt
:
# Prompt phase.
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
parent_ids
=
[
0
]
*
sampling_params
.
n
next_token_ids
=
random_samples
[
if
d2d_data
.
zero_overhead
:
sample_idx
,
:
sampling_params
.
n
].
tolist
()
next_token_ids
=
[
0
]
*
sampling_params
.
n
else
:
next_token_ids
=
random_samples
[
sample_idx
,
:
sampling_params
.
n
].
tolist
()
else
:
else
:
# Generation phase.
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
parent_ids
=
list
(
range
(
num_parent_seqs
))
next_token_ids
=
random_samples
[
sample_idx
:
sample_idx
+
if
d2d_data
.
zero_overhead
:
num_parent_seqs
,
0
].
tolist
()
next_token_ids
=
[
0
]
*
num_parent_seqs
else
:
next_token_ids
=
random_samples
[
sample_idx
:
sample_idx
+
num_parent_seqs
,
0
].
tolist
()
results
.
append
((
next_token_ids
,
parent_ids
))
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
sample_idx
+=
num_parent_seqs
d2d_data
.
sample_idx
=
sample_idx
return
results
return
results
...
@@ -707,7 +721,7 @@ def get_pythonized_sample_results(
...
@@ -707,7 +721,7 @@ def get_pythonized_sample_results(
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
d2d_data
.
random_samples
=
multinomial_samples
[
sampling_type
]
#记录random_samples的数据
d2d_data
.
random_samples
=
multinomial_samples
[
sampling_type
]
#记录random_samples的数据
sample_results
=
_random_sample
(
seq_groups
,
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
...
@@ -744,13 +758,11 @@ def _sample_with_torch(
...
@@ -744,13 +758,11 @@ def _sample_with_torch(
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
#初始化各种结果存储容器然后按照类型分类
}
print
(
f
'sampling_metadata.seq_groups的长度:
{
len
(
sampling_metadata
.
seq_groups
)
}
'
)
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
))
# 初始化一个tensor张量用于保存seq_id,初始值为-1
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
1
)
-
1
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
#将 i对应的seq_id存储到d2d_data.seq_id中
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -1280,13 +1292,18 @@ def _build_sampler_output(
...
@@ -1280,13 +1292,18 @@ def _build_sampler_output(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
None
)
if
d2d_data
.
zero_overhead
:
pass
return
SamplerOutput
(
return
SamplerOutput
(
outputs
=
sampler_output
,
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
logprobs
=
logprobs_tensor
,
deferred_sample_results_args
=
deferred_sample_results_args
,
deferred_sample_results_args
=
deferred_sample_results_args
,
logits
=
logits
)
logits
=
logits
,
sampler_out_tenosr
=
d2d_data
.
random_samples
,
sampler_out_ids
=
d2d_data
.
seq_id
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
...
...
vllm/sequence.py
View file @
7d867671
...
@@ -582,6 +582,11 @@ class Sequence:
...
@@ -582,6 +582,11 @@ class Sequence:
self
.
output_logprobs
.
append
(
logprobs
)
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
def
fix_last_token_id
(
self
,
token_id
:
int
)
->
None
:
self
.
data
.
_output_token_ids
[
-
2
]
=
token_id
self
.
data
.
_new_appended_tokens
[
-
2
]
=
token_id
self
.
data
.
_cached_all_token_ids
[
-
2
]
=
token_id
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
self
.
data
.
get_len
()
return
self
.
data
.
get_len
()
...
@@ -1403,7 +1408,10 @@ class ExecuteModelRequest(
...
@@ -1403,7 +1408,10 @@ class ExecuteModelRequest(
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
# for zero-overhead scheduler
last_outputs
:
Optional
[
torch
.
Tensor
]
=
None
last_outputs_sample
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1455,7 +1463,8 @@ class ExecuteModelRequest(
...
@@ -1455,7 +1463,8 @@ class ExecuteModelRequest(
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
,
tree_position_ids
=
self
.
tree_position_ids
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
last_outputs
=
self
.
last_outputs
)
last_outputs_sample
=
self
.
last_outputs_sample
,
last_outputs_ids
=
self
.
last_outputs_ids
)
@
dataclass
@
dataclass
...
...
vllm/worker/model_runner.py
View file @
7d867671
...
@@ -4,6 +4,7 @@ import dataclasses
...
@@ -4,6 +4,7 @@ import dataclasses
import
gc
import
gc
import
inspect
import
inspect
import
itertools
import
itertools
import
os
import
time
import
time
import
weakref
import
weakref
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -59,6 +60,8 @@ from vllm.worker.model_runner_base import (
...
@@ -59,6 +60,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.model_executor.layers.ops.update_input
import
UpdateInputTokens
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -476,6 +479,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -476,6 +479,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window_blocks
*
self
.
block_size
self
.
sliding_window_blocks
*
self
.
block_size
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
last_sample_tensor
=
None
self
.
last_sample_ids
=
None
self
.
req_ids
=
[]
def
SetLastSamperData
(
self
,
last_sample_ids
,
last_sample_tensor
):
self
.
last_sample_tensor
=
last_sample_tensor
self
.
last_sample_ids
=
last_sample_ids
def
prepare
(
self
,
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
...
@@ -491,6 +502,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -491,6 +502,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
self
.
attn_metadata_builder
.
prepare
()
self
.
attn_metadata_builder
.
prepare
()
self
.
req_ids
.
clear
()
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
seq_group_metadata
:
SequenceGroupMetadata
):
...
@@ -756,8 +768,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -756,8 +768,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
encoder_seq_len
)
encoder_seq_len
=
encoder_seq_len
)
self
.
inter_data_list
.
append
(
inter_data
)
self
.
inter_data_list
.
append
(
inter_data
)
seq_ids
=
list
(
seq_ids
)
for
seq_idx
in
range
(
n_seqs
):
for
seq_idx
in
range
(
n_seqs
):
self
.
req_ids
.
append
(
seq_ids
[
seq_idx
])
for
per_seq_fn
in
self
.
per_seq_compute_fns
:
for
per_seq_fn
in
self
.
per_seq_compute_fns
:
per_seq_fn
(
inter_data
,
seq_idx
,
seq_group_metadata
)
per_seq_fn
(
inter_data
,
seq_idx
,
seq_group_metadata
)
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
...
@@ -898,10 +911,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -898,10 +911,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
cuda_graph_pad_size
:
if
cuda_graph_pad_size
:
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
assert
self
.
runner
.
device
is
not
None
assert
self
.
runner
.
device
is
not
None
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
if
self
.
zero_overhead
and
self
.
last_sample_tensor
is
not
None
:
input_ids
=
torch
.
tensor
(
self
.
req_ids
,
device
=
'cuda'
)
UpdateInputTokens
(
input_tokens_tensor
,
input_ids
,
self
.
last_sample_tensor
,
self
.
last_sample_ids
)
print
(
'####input_tokens_tensor'
,
input_tokens_tensor
)
print
(
'####input_ids'
,
input_ids
)
print
(
'####self.last_sample_tensor'
,
self
.
last_sample_tensor
)
print
(
'####self.last_sample_ids'
,
self
.
last_sample_ids
)
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
...
@@ -1200,7 +1221,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1200,7 +1221,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
_prepare_model_input_tensors
(
def
_prepare_model_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
TModelInputForGPU
:
)
->
TModelInputForGPU
:
"""Helper method to prepare the model input based on a given sequence
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
group. Prepares metadata needed for the base model forward pass but not
...
@@ -1221,7 +1244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1221,7 +1244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
SetLastSamperData
(
last_outputs_ids
,
last_output_sample
)
return
self
.
builder
.
build
()
# type: ignore
return
self
.
builder
.
build
()
# type: ignore
@
contextmanager
@
contextmanager
...
@@ -1616,6 +1639,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1616,6 +1639,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
metadata for the sampling step.
...
@@ -1631,7 +1656,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1631,7 +1656,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
If cuda graph is required, this API automatically pads inputs.
"""
"""
model_input
=
self
.
_prepare_model_input_tensors
(
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
seq_group_metadata_list
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
# Sampling metadata is only required for the final pp group
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
generators
=
self
.
get_generators
(
finished_requests_ids
)
...
...
vllm/worker/model_runner_base.py
View file @
7d867671
...
@@ -189,7 +189,6 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -189,7 +189,6 @@ class ModelRunnerBase(ABC, Generic[T]):
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
last_output
=
None
# Map of request_id -> generator used for seeded random sampling
# Map of request_id -> generator used for seeded random sampling
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
...
@@ -211,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -211,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
T
:
)
->
T
:
"""
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
Prepare the inputs to ModelRunnerBase.execute_model from an execution
...
...
vllm/worker/worker_base.py
View file @
7d867671
...
@@ -353,12 +353,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -353,12 +353,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
self
.
model_runner
.
last_output
=
execute_model_req
.
last_outputs
model_input
:
ModelRunnerInputBase
=
(
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
,
last_outputs_ids
=
execute_model_req
.
last_outputs_ids
,
last_output_sample
=
execute_model_req
.
last_outputs_sample
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
execute_model_req
.
tree_attn_masks
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment