Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9bd32639
Commit
9bd32639
authored
Mar 17, 2025
by
lizhigong
Browse files
zero overhead engine update
parent
6b7651af
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
158 additions
and
11 deletions
+158
-11
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+35
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/model_executor/layers/ops/update_input.py
vllm/model_executor/layers/ops/update_input.py
+24
-0
vllm/profiler/prof.py
vllm/profiler/prof.py
+73
-0
vllm/sequence.py
vllm/sequence.py
+5
-1
vllm/version.py
vllm/version.py
+6
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-2
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+1
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+4
-1
No files found.
vllm/engine/llm_engine.py
View file @
9bd32639
...
@@ -61,6 +61,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
...
@@ -61,6 +61,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
@@ -407,6 +408,11 @@ class LLMEngine:
...
@@ -407,6 +408,11 @@ class LLMEngine:
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
step_switch
=
0
# 0 step A 1 step B
self
.
output_recorder
=
[
None
,
None
]
profile
.
StartTracer
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1271,6 +1277,9 @@ class LLMEngine:
...
@@ -1271,6 +1277,9 @@ class LLMEngine:
else
:
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
trans_last_output_tensor
(
self
,
last_output
)
->
torch
.
Tensor
:
return
None
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
...
@@ -1346,6 +1355,7 @@ class LLMEngine:
...
@@ -1346,6 +1355,7 @@ class LLMEngine:
# Skip the scheduler if there are any remaining steps in the seq groups.
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
profile
.
ProfRangeAutoPush
(
'has_remain'
)
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Schedule iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
...
@@ -1375,6 +1385,10 @@ class LLMEngine:
...
@@ -1375,6 +1385,10 @@ class LLMEngine:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
profile
.
ProfRangeAutoPush
(
'execute_model'
)
last_outputs
=
None
if
self
.
zero_overhead
:
last_outputs
=
self
.
trans_last_output_tensor
(
self
.
output_recorder
[
self
.
step_switch
])
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
# Check if we have a cached last_output from the previous iteration.
# Check if we have a cached last_output from the previous iteration.
...
@@ -1384,6 +1398,14 @@ class LLMEngine:
...
@@ -1384,6 +1398,14 @@ class LLMEngine:
last_sampled_token_ids
=
\
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
# print('seq_group_metadata_list', len(seq_group_metadata_list))
# print('scheduler_outputs.blocks_to_swap_in', len(scheduler_outputs.blocks_to_swap_in))
# print('scheduler_outputs.num_lookahead_slots', scheduler_outputs.num_lookahead_slots)
# print('scheduler_outputs.running_queue_size', scheduler_outputs.running_queue_size)
# print('finished_requests_ids', len(finished_requests_ids))
# print('last_sampled_token_ids', last_sampled_token_ids)
# print('self.model_executor', type(self.model_executor))
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
@@ -1394,15 +1416,15 @@ class LLMEngine:
...
@@ -1394,15 +1416,15 @@ class LLMEngine:
finished_requests_ids
=
finished_requests_ids
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs
=
last_outputs
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
virtual_engine
]
outputs
=
self
.
model_executor
.
execute_model
(
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
print
(
'###outputs'
,
outputs
)
# We need to do this here so that last step's sampled_token_ids can
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
...
@@ -1420,6 +1442,15 @@ class LLMEngine:
...
@@ -1420,6 +1442,15 @@ class LLMEngine:
for
seq_group
in
seq_group_metadata_list
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
seq_group
.
finish_step
()
if
self
.
zero_overhead
:
self
.
output_recorder
[
self
.
step_switch
]
=
outputs
self
.
step_switch
=
1
-
self
.
step_switch
outputs
=
self
.
output_recorder
[
self
.
step_switch
]
if
outputs
is
None
:
return
None
#同步上一次的output
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps.
# clear the cache if we have finished all the steps.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
...
@@ -1460,6 +1491,7 @@ class LLMEngine:
...
@@ -1460,6 +1491,7 @@ class LLMEngine:
# Multi-step case
# Multi-step case
return
ctx
.
request_outputs
return
ctx
.
request_outputs
profile
.
ProfRangeAutoPush
(
'has_unfinish'
)
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
...
...
vllm/entrypoints/llm.py
View file @
9bd32639
...
@@ -1388,6 +1388,8 @@ class LLM:
...
@@ -1388,6 +1388,8 @@ class LLM:
total_out_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
step_outputs
=
self
.
llm_engine
.
step
()
if
step_outputs
is
None
:
continue
for
output
in
step_outputs
:
for
output
in
step_outputs
:
if
output
.
finished
:
if
output
.
finished
:
outputs
.
append
(
output
)
outputs
.
append
(
output
)
...
...
vllm/model_executor/layers/ops/update_input.py
0 → 100644
View file @
9bd32639
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_update_input_tokens
(
sample_output
,
seq_ids
,
input_tokens
,
input_seq_ids
,
BATCH_SIZE1
,
BATCH_SIZE2
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
>=
BATCH_SIZE2
:
return
output_token
=
tl
.
load
(
input_tokens
+
pid
)
_input_seq_id
=
tl
.
load
(
input_seq_ids
+
pid
)
for
i
in
range
(
BATCH_SIZE1
):
_seq_ids
=
tl
.
load
(
seq_ids
+
i
)
if
_seq_ids
==
_input_seq_id
:
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
\ No newline at end of file
vllm/profiler/prof.py
0 → 100644
View file @
9bd32639
from
ctypes
import
*
import
os
import
time
import
threading
class
Prof
:
def
__init__
(
self
):
self
.
use_nvtx
=
os
.
getenv
(
'VLLM_PROF_NVTX'
)
is
not
None
self
.
roc_tracer_flag
=
False
self
.
lib
=
None
if
self
.
use_nvtx
:
self
.
lib
=
cdll
.
LoadLibrary
(
"libnvToolsExt.so"
)
self
.
lib
.
nvtxRangePushA
.
argtypes
=
[
c_char_p
]
self
.
lib
.
nvtxRangePushA
.
restype
=
c_int
self
.
lib
.
nvtxRangePop
.
restype
=
c_int
self
.
use_roctx
=
os
.
getenv
(
'VLLM_PROF_ROCTX'
)
is
not
None
if
self
.
use_roctx
:
self
.
lib
=
cdll
.
LoadLibrary
(
"libroctracer64.so"
)
self
.
lib
.
roctxRangePushA
.
argtypes
=
[
c_char_p
]
self
.
lib
.
roctxRangePushA
.
restype
=
c_int
self
.
lib
.
roctxRangePop
.
restype
=
c_int
self
.
tm
=
time
.
perf_counter
()
self
.
push_depth
=
{}
def
StartTracer
(
self
):
if
self
.
use_roctx
:
if
self
.
lib
is
None
:
self
.
lib
=
cdll
.
LoadLibrary
(
"libroctracer64.so"
)
self
.
lib
.
roctracer_start
()
self
.
roc_tracer_flag
=
True
def
StopTracer
(
self
):
if
self
.
use_roctx
:
if
self
.
lib
is
None
:
self
.
lib
=
cdll
.
LoadLibrary
(
"libroctracer64.so"
)
self
.
lib
.
roctracer_stop
()
self
.
roc_tracer_flag
=
False
def
thread_depth_add
(
self
,
num
):
current_thread
=
threading
.
current_thread
()
thread_id
=
current_thread
.
ident
if
thread_id
not
in
self
.
push_depth
.
keys
():
self
.
push_depth
[
thread_id
]
=
0
if
num
<
0
and
self
.
push_depth
[
thread_id
]
==
0
:
return
False
self
.
push_depth
[
thread_id
]
+=
num
return
True
def
ProfRangePush
(
self
,
message
):
if
profile
.
use_nvtx
:
profile
.
lib
.
nvtxRangePushA
(
message
.
encode
(
'utf-8'
))
self
.
thread_depth_add
(
1
)
if
profile
.
use_roctx
and
self
.
roc_tracer_flag
:
profile
.
lib
.
roctxRangePushA
(
message
.
encode
(
'utf-8'
))
self
.
thread_depth_add
(
1
)
def
ProfRangePop
(
self
):
if
profile
.
use_nvtx
:
if
not
self
.
thread_depth_add
(
-
1
):
return
profile
.
lib
.
nvtxRangePop
()
if
profile
.
use_roctx
and
self
.
roc_tracer_flag
:
if
not
self
.
thread_depth_add
(
-
1
):
return
profile
.
lib
.
roctxRangePop
()
def
ProfRangeAutoPush
(
self
,
message
):
self
.
ProfRangePop
()
self
.
ProfRangePush
(
message
)
profile
=
Prof
()
vllm/sequence.py
View file @
9bd32639
...
@@ -1402,6 +1402,9 @@ class ExecuteModelRequest(
...
@@ -1402,6 +1402,9 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
# TODO(will) make this be able to handle batches with variable number of
...
@@ -1451,7 +1454,8 @@ class ExecuteModelRequest(
...
@@ -1451,7 +1454,8 @@ class ExecuteModelRequest(
async_callback
=
self
.
async_callback
,
async_callback
=
self
.
async_callback
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
,
tree_position_ids
=
self
.
tree_position_ids
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
)
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
last_outputs
=
self
.
last_outputs
)
@
dataclass
@
dataclass
...
...
vllm/version.py
View file @
9bd32639
# SPDX-License-Identifier: Apache-2.0
try
:
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.7.2"
__version_tuple__
=
(
0
,
7
,
2
)
__hcu_version__
=
f
'0.7.2+das.opt1.cust1.6b7651a.dtk2504'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
except
Exception
as
e
:
import
warnings
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)
"
,
RuntimeWarning
,
RuntimeWarning
,
stacklevel
=
2
)
stacklevel
=
2
)
__version__
=
"dev"
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
__version_tuple__
=
(
0
,
0
,
__version__
)
vllm/worker/model_runner.py
View file @
9bd32639
...
@@ -59,6 +59,8 @@ from vllm.worker.model_runner_base import (
...
@@ -59,6 +59,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.profiler.prof
import
profile
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -271,7 +273,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -271,7 +273,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
n_seqs
=
n_seqs
self
.
n_seqs
=
n_seqs
self
.
encoder_seq_len
=
encoder_seq_len
self
.
encoder_seq_len
=
encoder_seq_len
if
reinit
:
if
reinit
:
if
len
(
self
.
seq_ids
)
==
1
and
reinit_use_defaults
:
if
len
(
self
.
seq_ids
)
==
1
and
reinit_use_defaults
:
self
.
simple_reinit
()
self
.
simple_reinit
()
...
@@ -900,6 +901,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -900,6 +901,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
...
@@ -1670,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1670,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
.
set_active_prompt_adapters
(
self
.
set_active_prompt_adapters
(
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
model_input
.
prompt_adapter_mapping
)
profile
.
ProfRangeAutoPush
(
'begin_forward'
)
self
.
attn_state
.
begin_forward
(
model_input
)
self
.
attn_state
.
begin_forward
(
model_input
)
# Currently cuda graph is only supported by the decode phase.
# Currently cuda graph is only supported by the decode phase.
...
@@ -1772,6 +1774,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1772,6 +1774,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
torch
.
tensor
(
model_forward_time
+
orig_model_forward_time
))
torch
.
tensor
(
model_forward_time
+
orig_model_forward_time
))
return
hidden_or_intermediate_states
return
hidden_or_intermediate_states
profile
.
ProfRangeAutoPush
(
'compute_logits'
)
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
)
...
@@ -1782,10 +1785,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1782,10 +1785,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input
.
async_callback
()
model_input
.
async_callback
()
# Sample the next token.
# Sample the next token.
profile
.
ProfRangeAutoPush
(
'sample'
)
output
:
SamplerOutput
=
self
.
model
.
sample
(
output
:
SamplerOutput
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
)
profile
.
ProfRangeAutoPush
(
'sample_end'
)
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
and
self
.
observability_config
.
collect_model_forward_time
and
output
is
not
None
):
and
output
is
not
None
):
...
@@ -1803,6 +1808,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1803,6 +1808,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output
.
model_forward_time
=
(
orig_model_forward_time
+
output
.
model_forward_time
=
(
orig_model_forward_time
+
model_forward_time
)
model_forward_time
)
profile
.
ProfRangeAutoPush
(
'output'
)
if
self
.
return_hidden_states
:
if
self
.
return_hidden_states
:
# we only need to pass hidden states of most recent token
# we only need to pass hidden states of most recent token
assert
model_input
.
sampling_metadata
is
not
None
assert
model_input
.
sampling_metadata
is
not
None
...
...
vllm/worker/model_runner_base.py
View file @
9bd32639
...
@@ -189,6 +189,7 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -189,6 +189,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
last_output
=
None
# Map of request_id -> generator used for seeded random sampling
# Map of request_id -> generator used for seeded random sampling
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
...
...
vllm/worker/worker_base.py
View file @
9bd32639
...
@@ -25,6 +25,7 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
...
@@ -25,6 +25,7 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase
,
ModelRunnerBase
,
ModelRunnerInputBase
)
ModelRunnerInputBase
)
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
self
.
model_runner
.
last_output
=
execute_model_req
.
last_outputs
model_input
:
ModelRunnerInputBase
=
(
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
...
@@ -444,7 +446,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -444,7 +446,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
and
self
.
observability_config
.
collect_model_execute_time
):
and
self
.
observability_config
.
collect_model_execute_time
):
orig_model_execute_time
=
intermediate_tensors
.
tensors
.
get
(
orig_model_execute_time
=
intermediate_tensors
.
tensors
.
get
(
"model_execute_time"
,
torch
.
tensor
(
0
)).
item
()
"model_execute_time"
,
torch
.
tensor
(
0
)).
item
()
profile
.
ProfRangeAutoPush
(
'execute'
)
output
=
self
.
model_runner
.
execute_model
(
output
=
self
.
model_runner
.
execute_model
(
model_input
=
model_input
,
model_input
=
model_input
,
kv_caches
=
self
.
kv_cache
[
worker_input
.
virtual_engine
]
kv_caches
=
self
.
kv_cache
[
worker_input
.
virtual_engine
]
...
@@ -453,6 +455,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -453,6 +455,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
num_steps
=
num_steps
,
num_steps
=
num_steps
,
**
kwargs
,
**
kwargs
,
)
)
profile
.
ProfRangeAutoPush
(
'output'
)
model_execute_time
=
time
.
perf_counter
()
-
start_time
model_execute_time
=
time
.
perf_counter
()
-
start_time
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment