Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c08e2b30
Unverified
Commit
c08e2b30
authored
Aug 11, 2024
by
William Lin
Committed by
GitHub
Aug 11, 2024
Browse files
[core] [2/N] refactor worker_base input preparation for multi-step (#7387)
parent
4fb7b52a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
31 deletions
+63
-31
vllm/worker/worker.py
vllm/worker/worker.py
+2
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+61
-31
No files found.
vllm/worker/worker.py
View file @
c08e2b30
...
...
@@ -264,6 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
def
prepare_worker_input
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
WorkerInput
:
virtual_engine
=
execute_model_req
.
virtual_engine
num_steps
=
execute_model_req
.
num_steps
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
...
...
@@ -286,6 +287,7 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
num_steps
=
num_steps
,
)
@
torch
.
inference_mode
()
...
...
vllm/worker/worker_base.py
View file @
c08e2b30
...
...
@@ -129,6 +129,7 @@ class WorkerInput:
blocks_to_swap_out
:
Optional
[
torch
.
Tensor
]
=
None
blocks_to_copy
:
Optional
[
torch
.
Tensor
]
=
None
virtual_engine
:
int
=
0
num_steps
:
int
=
1
@
classmethod
def
from_broadcasted_tensor_dict
(
...
...
@@ -145,6 +146,7 @@ class WorkerInput:
blocks_to_swap_out
=
tensor_dict
.
pop
(
"blocks_to_swap_out"
),
blocks_to_copy
=
tensor_dict
.
pop
(
"blocks_to_copy"
),
virtual_engine
=
tensor_dict
[
"virtual_engine"
],
num_steps
=
tensor_dict
.
pop
(
"num_steps"
),
)
def
as_broadcastable_tensor_dict
(
...
...
@@ -158,6 +160,7 @@ class WorkerInput:
"blocks_to_swap_out"
:
self
.
blocks_to_swap_out
,
"blocks_to_copy"
:
self
.
blocks_to_copy
,
"virtual_engine"
:
self
.
virtual_engine
,
"num_steps"
:
self
.
num_steps
,
}
return
tensor_dict
...
...
@@ -216,24 +219,28 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"""
raise
NotImplementedError
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
Optional
[
List
[
SamplerOutput
]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time
=
time
.
perf_counter
()
if
self
.
is_driver_worker
:
if
execute_model_req
is
None
:
if
self
.
do_metadata_broadcast
:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict
({},
src
=
0
)
def
_get_worker_input_from_broadcast
(
self
)
->
Optional
[
Tuple
[
ModelRunnerInputBase
,
WorkerInput
]]:
""" Get the worker input from the broadcasted tensor dict. """
assert
self
.
do_metadata_broadcast
assert
not
self
.
is_driver_worker
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
not
broadcast_data
:
return
None
worker_input
=
WorkerInput
.
from_broadcasted_tensor_dict
(
broadcast_data
)
model_input
=
(
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
broadcast_data
))
return
model_input
,
worker_input
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Tuple
[
ModelRunnerInputBase
,
WorkerInput
]:
""" Get the driver input and broadcast it to other workers. """
assert
self
.
is_driver_worker
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
model_input
:
ModelRunnerInputBase
=
(
...
...
@@ -241,26 +248,49 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
num_steps
=
execute_model_req
.
num_steps
if
self
.
do_metadata_broadcast
:
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_data
[
"num_steps"
]
=
num_steps
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
return
model_input
,
worker_input
def
prepare_input
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
Optional
[
Tuple
[
ModelRunnerInputBase
,
WorkerInput
]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if
self
.
is_driver_worker
:
if
execute_model_req
is
None
:
if
self
.
do_metadata_broadcast
:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict
({},
src
=
0
)
return
None
return
self
.
_get_driver_input_and_broadcast
(
execute_model_req
)
else
:
assert
self
.
do_metadata_broadcast
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
not
broadcast_data
:
return
self
.
_get_worker_input_from_broadcast
()
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
Optional
[
List
[
SamplerOutput
]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time
=
time
.
perf_counter
()
inputs
=
self
.
prepare_input
(
execute_model_req
)
if
inputs
is
None
:
return
None
num_steps
=
broadcast_data
.
pop
(
"num_steps"
)
worker_input
=
WorkerInput
.
from_broadcasted_tensor_dict
(
broadcast_data
)
model_input
=
(
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
broadcast_data
))
model_input
,
worker_input
=
inputs
num_steps
=
worker_input
.
num_steps
self
.
execute_worker
(
worker_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment