Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a657bfc4
Unverified
Commit
a657bfc4
authored
May 01, 2024
by
Nick Hill
Committed by
GitHub
May 01, 2024
Browse files
[Core] Add `multiproc_worker_utils` for multiprocessing-based workers (#4357)
parent
24750f4c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
440 additions
and
0 deletions
+440
-0
tests/engine/test_multiproc_workers.py
tests/engine/test_multiproc_workers.py
+176
-0
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+264
-0
No files found.
tests/engine/test_multiproc_workers.py
0 → 100644
View file @
a657bfc4
import
asyncio
from
concurrent.futures
import
ThreadPoolExecutor
from
functools
import
partial
from
time
import
sleep
from
typing
import
Any
,
List
,
Tuple
import
pytest
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
class
DummyWorker
:
"""Dummy version of vllm.worker.worker.Worker"""
def
__init__
(
self
,
rank
:
int
):
self
.
rank
=
rank
def
worker_method
(
self
,
worker_input
:
Any
)
->
Tuple
[
int
,
Any
]:
sleep
(
0.05
)
if
isinstance
(
worker_input
,
Exception
):
# simulate error case
raise
worker_input
return
self
.
rank
,
input
def
_start_workers
()
->
Tuple
[
List
[
ProcessWorkerWrapper
],
WorkerMonitor
]:
result_handler
=
ResultHandler
()
workers
=
[
ProcessWorkerWrapper
(
result_handler
,
partial
(
DummyWorker
,
rank
=
rank
))
for
rank
in
range
(
8
)
]
worker_monitor
=
WorkerMonitor
(
workers
,
result_handler
)
assert
not
worker_monitor
.
is_alive
()
result_handler
.
start
()
worker_monitor
.
start
()
assert
worker_monitor
.
is_alive
()
return
workers
,
worker_monitor
def
test_local_workers
()
->
None
:
"""Test workers with sync task submission"""
workers
,
worker_monitor
=
_start_workers
()
def
execute_workers
(
worker_input
:
str
)
->
None
:
worker_outputs
=
[
worker
.
execute_method
(
"worker_method"
,
worker_input
)
for
worker
in
workers
]
for
rank
,
output
in
enumerate
(
worker_outputs
):
assert
output
.
get
()
==
(
rank
,
input
)
executor
=
ThreadPoolExecutor
(
max_workers
=
4
)
# Test concurrent submission from different threads
futures
=
[
executor
.
submit
(
partial
(
execute_workers
,
f
"thread
{
thread_num
}
"
))
for
thread_num
in
range
(
4
)
]
for
future
in
futures
:
future
.
result
()
# Test error case
exception
=
ValueError
(
"fake error"
)
result
=
workers
[
0
].
execute_method
(
"worker_method"
,
exception
)
try
:
result
.
get
()
pytest
.
fail
(
"task should have failed"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ValueError
)
assert
str
(
e
)
==
"fake error"
# Test cleanup when a worker fails
assert
worker_monitor
.
is_alive
()
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
2
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
workers
[
0
].
execute_method
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
def
test_local_workers_clean_shutdown
()
->
None
:
"""Test clean shutdown"""
workers
,
worker_monitor
=
_start_workers
()
assert
worker_monitor
.
is_alive
()
assert
all
(
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Clean shutdown
worker_monitor
.
close
()
worker_monitor
.
join
(
5
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
workers
[
0
].
execute_method
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
@
pytest
.
mark
.
asyncio
async
def
test_local_workers_async
()
->
None
:
"""Test local workers with async task submission"""
workers
,
worker_monitor
=
_start_workers
()
async
def
execute_workers
(
worker_input
:
str
)
->
None
:
worker_coros
=
[
worker
.
execute_method_async
(
"worker_method"
,
worker_input
)
for
worker
in
workers
]
results
=
await
asyncio
.
gather
(
*
worker_coros
)
for
rank
,
result
in
enumerate
(
results
):
assert
result
==
(
rank
,
input
)
tasks
=
[
asyncio
.
create_task
(
execute_workers
(
f
"task
{
task_num
}
"
))
for
task_num
in
range
(
4
)
]
for
task
in
tasks
:
await
task
# Test error case
exception
=
ValueError
(
"fake error"
)
try
:
_result
=
await
workers
[
0
].
execute_method_async
(
"worker_method"
,
exception
)
pytest
.
fail
(
"task should have failed"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ValueError
)
assert
str
(
e
)
==
"fake error"
# Test cleanup when a worker fails
assert
worker_monitor
.
is_alive
()
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
2
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
await
workers
[
0
].
execute_method_async
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
vllm/executor/multiproc_worker_utils.py
0 → 100644
View file @
a657bfc4
import
asyncio
import
multiprocessing
import
os
import
sys
import
threading
import
traceback
import
uuid
from
dataclasses
import
dataclass
from
multiprocessing
import
Queue
from
multiprocessing.connection
import
wait
from
multiprocessing.process
import
BaseProcess
from
typing
import
(
Any
,
Callable
,
Dict
,
Generic
,
List
,
Optional
,
TextIO
,
TypeVar
,
Union
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
'T'
)
_TERMINATE
=
"TERMINATE"
# sentinel
# ANSI color codes
CYAN
=
'
\033
[1;36m'
RESET
=
'
\033
[0;0m'
JOIN_TIMEOUT_S
=
2
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
mp_method
=
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
mp
=
multiprocessing
.
get_context
(
mp_method
)
@
dataclass
class
Result
(
Generic
[
T
]):
"""Result of task dispatched to worker"""
task_id
:
uuid
.
UUID
value
:
Optional
[
T
]
=
None
exception
:
Optional
[
BaseException
]
=
None
class
ResultFuture
(
threading
.
Event
,
Generic
[
T
]):
"""Synchronous future for non-async case"""
def
__init__
(
self
):
super
().
__init__
()
self
.
result
:
Optional
[
Result
[
T
]]
=
None
def
set_result
(
self
,
result
:
Result
[
T
]):
self
.
result
=
result
self
.
set
()
def
get
(
self
)
->
T
:
self
.
wait
()
assert
self
.
result
is
not
None
if
self
.
result
.
exception
is
not
None
:
raise
self
.
result
.
exception
return
self
.
result
.
value
# type: ignore[return-value]
def
_set_future_result
(
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
result
:
Result
):
if
isinstance
(
future
,
ResultFuture
):
future
.
set_result
(
result
)
return
loop
=
future
.
get_loop
()
if
result
.
exception
is
not
None
:
loop
.
call_soon_threadsafe
(
future
.
set_exception
,
result
.
exception
)
else
:
loop
.
call_soon_threadsafe
(
future
.
set_result
,
result
.
value
)
class
ResultHandler
(
threading
.
Thread
):
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
True
)
self
.
result_queue
=
mp
.
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
def
run
(
self
):
for
result
in
iter
(
self
.
result_queue
.
get
,
_TERMINATE
):
future
=
self
.
tasks
.
pop
(
result
.
task_id
)
_set_future_result
(
future
,
result
)
# Ensure that all waiters will receive an exception
for
task_id
,
future
in
self
.
tasks
.
items
():
_set_future_result
(
future
,
Result
(
task_id
=
task_id
,
exception
=
ChildProcessError
(
"worker died"
)))
def
close
(
self
):
self
.
result_queue
.
put
(
_TERMINATE
)
class
WorkerMonitor
(
threading
.
Thread
):
"""Monitor worker status (in background thread)"""
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
True
)
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
_close
=
False
def
run
(
self
)
->
None
:
# Blocks until any worker exits
dead_sentinels
=
wait
([
w
.
process
.
sentinel
for
w
in
self
.
workers
])
if
not
self
.
_close
:
self
.
_close
=
True
# Kill / cleanup all workers
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
# Cleanup any remaining workers
logger
.
info
(
"Killing local vLLM worker processes"
)
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
self
.
result_handler
.
close
()
for
worker
in
self
.
workers
:
worker
.
process
.
join
(
JOIN_TIMEOUT_S
)
def
close
(
self
):
if
self
.
_close
:
return
self
.
_close
=
True
logger
.
info
(
"Terminating local vLLM worker processes"
)
for
worker
in
self
.
workers
:
worker
.
terminate_worker
()
# Must be done after worker task queues are all closed
self
.
result_handler
.
close
()
class
ProcessWorkerWrapper
:
"""Local process wrapper for vllm.worker.Worker,
for handling single-node multi-GPU tensor parallel."""
def
__init__
(
self
,
result_handler
:
ResultHandler
,
worker_factory
:
Callable
[[],
Any
])
->
None
:
self
.
_task_queue
=
mp
.
Queue
()
self
.
result_queue
=
result_handler
.
result_queue
self
.
tasks
=
result_handler
.
tasks
self
.
process
:
BaseProcess
=
mp
.
Process
(
# type: ignore[attr-defined]
target
=
_run_worker_process
,
name
=
"VllmWorkerProcess"
,
kwargs
=
dict
(
worker_factory
=
worker_factory
,
task_queue
=
self
.
_task_queue
,
result_queue
=
self
.
result_queue
,
),
daemon
=
True
)
self
.
process
.
start
()
def
_enqueue_task
(
self
,
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
method
:
str
,
args
,
kwargs
):
task_id
=
uuid
.
uuid4
()
self
.
tasks
[
task_id
]
=
future
try
:
self
.
_task_queue
.
put
((
task_id
,
method
,
args
,
kwargs
))
except
BaseException
as
e
:
del
self
.
tasks
[
task_id
]
raise
ChildProcessError
(
"worker died"
)
from
e
def
execute_method
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
future
:
ResultFuture
=
ResultFuture
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
future
async
def
execute_method_async
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
await
future
def
terminate_worker
(
self
):
try
:
self
.
_task_queue
.
put
(
_TERMINATE
)
except
ValueError
:
self
.
process
.
kill
()
self
.
_task_queue
.
close
()
def
kill_worker
(
self
):
self
.
_task_queue
.
close
()
self
.
process
.
kill
()
def
_run_worker_process
(
worker_factory
:
Callable
[[],
Any
],
task_queue
:
Queue
,
result_queue
:
Queue
,
)
->
None
:
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name
=
mp
.
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
# Initialize worker
worker
=
worker_factory
()
del
worker_factory
# Accept tasks from the engine in task_queue
# and return task output in result_queue
logger
.
info
(
"Worker ready; awaiting tasks"
)
try
:
for
items
in
iter
(
task_queue
.
get
,
_TERMINATE
):
output
=
None
exception
=
None
task_id
,
method
,
args
,
kwargs
=
items
try
:
executor
=
getattr
(
worker
,
method
)
output
=
executor
(
*
args
,
**
kwargs
)
except
BaseException
as
e
:
tb
=
traceback
.
format_exc
()
logger
.
error
(
"Exception in worker %s while processing method %s: %s, %s"
,
process_name
,
method
,
e
,
tb
)
exception
=
e
result_queue
.
put
(
Result
(
task_id
=
task_id
,
value
=
output
,
exception
=
exception
))
except
KeyboardInterrupt
:
pass
except
Exception
:
logger
.
exception
(
"Worker failed"
)
logger
.
info
(
"Worker exiting"
)
def
_add_prefix
(
file
:
TextIO
,
worker_name
:
str
,
pid
:
int
)
->
None
:
"""Prepend each output line with process-specific prefix"""
prefix
=
f
"
{
CYAN
}
(
{
worker_name
}
pid=
{
pid
}
)
{
RESET
}
"
file_write
=
file
.
write
def
write_with_prefix
(
s
:
str
):
if
not
s
:
return
if
file
.
start_new_line
:
# type: ignore[attr-defined]
file_write
(
prefix
)
idx
=
0
while
(
next_idx
:
=
s
.
find
(
'
\n
'
,
idx
))
!=
-
1
:
next_idx
+=
1
file_write
(
s
[
idx
:
next_idx
])
if
next_idx
==
len
(
s
):
file
.
start_new_line
=
True
# type: ignore[attr-defined]
return
file_write
(
prefix
)
idx
=
next_idx
file_write
(
s
[
idx
:])
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment