Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7ed82d19
Unverified
Commit
7ed82d19
authored
Sep 20, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 20, 2025
Browse files
[V0 Deprecation] Remove V0 MP executor (#25329)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
12dbd834
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
530 deletions
+33
-530
vllm/executor/mp_distributed_executor.py
vllm/executor/mp_distributed_executor.py
+0
-244
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+0
-279
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+33
-7
No files found.
vllm/executor/mp_distributed_executor.py
deleted
100644 → 0
View file @
12dbd834
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
import
cloudpickle
from
vllm.executor.executor_base
import
DistributedExecutorBase
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
,
set_multiprocessing_worker_envs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
,
run_method
,
update_environment_variables
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
class
MultiprocessingDistributedExecutor
(
DistributedExecutorBase
):
"""Python multiprocessing-based distributed executor"""
uses_ray
:
bool
=
False
def
_check_cuda
(
self
)
->
None
:
"""Check that the number of GPUs is sufficient for the parallel
configuration. Separate from _init_executor to reduce the number of
indented blocks.
"""
parallel_config
=
self
.
parallel_config
world_size
=
parallel_config
.
world_size
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
if
tensor_parallel_size
>
cuda_device_count
:
raise
RuntimeError
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
if
world_size
>
cuda_device_count
:
raise
RuntimeError
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
def
_init_executor
(
self
)
->
None
:
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
self
.
_check_cuda
()
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs
(
self
.
parallel_config
)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
distributed_init_method
=
get_distributed_init_method
(
"127.0.0.1"
,
get_open_port
())
self
.
workers
:
List
[
ProcessWorkerWrapper
]
=
[]
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self
.
tp_driver_workers
:
List
[
ProcessWorkerWrapper
]
=
[]
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self
.
non_driver_workers
:
List
[
ProcessWorkerWrapper
]
=
[]
if
world_size
==
1
:
self
.
worker_monitor
=
None
else
:
result_handler
=
ResultHandler
()
for
rank
in
range
(
1
,
world_size
):
worker
=
ProcessWorkerWrapper
(
result_handler
,
WorkerWrapperBase
,
self
.
vllm_config
,
rank
)
self
.
workers
.
append
(
worker
)
if
rank
%
tensor_parallel_size
==
0
:
self
.
tp_driver_workers
.
append
(
worker
)
else
:
self
.
non_driver_workers
.
append
(
worker
)
self
.
worker_monitor
=
WorkerMonitor
(
self
.
workers
,
result_handler
)
result_handler
.
start
()
self
.
worker_monitor
.
start
()
# Set up signal handlers to shut down the executor cleanly
# sometimes gc does not work well
self
.
driver_worker
=
WorkerWrapperBase
(
self
.
vllm_config
,
0
)
all_kwargs
=
[]
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
for
i
in
range
(
world_size
):
local_rank
=
i
rank
=
i
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
(
not
self
.
parallel_config
)
or
(
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
),
)
all_kwargs
.
append
(
kwargs
)
self
.
_run_workers
(
"init_worker"
,
all_kwargs
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
self
.
driver_exec_model
=
make_async
(
self
.
driver_worker
.
execute_model
)
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
def
shutdown
(
self
):
if
(
worker_monitor
:
=
getattr
(
self
,
"worker_monitor"
,
None
))
is
not
None
:
worker_monitor
.
close
()
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
)
->
Optional
[
List
[
SamplerOutput
]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return
self
.
driver_worker
.
execute_model
(
execute_model_req
)
def
_run_workers
(
self
,
method
:
Union
[
str
,
Callable
],
*
args
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
List
[
Any
]:
"""Runs the given method on all workers.
Args:
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
if
isinstance
(
method
,
str
):
sent_method
=
method
else
:
sent_method
=
cloudpickle
.
dumps
(
method
)
del
method
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
if
async_run_tensor_parallel_workers_only
:
# Run only non-driver workers and just return futures.
return
[
worker
.
execute_method
(
sent_method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
non_driver_workers
]
# Start all remote workers first.
worker_outputs
=
[
worker
.
execute_method
(
sent_method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
driver_worker_output
=
run_method
(
self
.
driver_worker
,
sent_method
,
args
,
kwargs
)
# Get the results of the workers.
return
[
driver_worker_output
]
+
[
output
.
get
()
for
output
in
worker_outputs
]
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
if
self
.
worker_monitor
is
not
None
and
not
self
.
worker_monitor
.
is_alive
(
):
raise
RuntimeError
(
"Worker processes are not running"
)
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for
result
in
parallel_worker_tasks
:
result
.
get
()
async
def
_driver_execute_model_async
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
if
not
self
.
tp_driver_workers
:
return
await
self
.
driver_exec_model
(
execute_model_req
)
if
self
.
pp_locks
is
None
:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self
.
pp_locks
=
[
asyncio
.
Lock
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
tasks
=
[
asyncio
.
create_task
(
_run_task_with_lock
(
self
.
driver_exec_model
,
self
.
pp_locks
[
0
],
execute_model_req
))
]
for
pp_rank
,
driver_worker
in
enumerate
(
self
.
tp_driver_workers
,
start
=
1
):
tasks
.
append
(
asyncio
.
create_task
(
_run_task_with_lock
(
driver_worker
.
execute_method_async
,
self
.
pp_locks
[
pp_rank
],
"execute_model"
,
execute_model_req
)))
results
=
await
asyncio
.
gather
(
*
tasks
)
# Only the last PP stage has the final results.
return
results
[
-
1
]
async
def
_start_worker_execution_loop
(
self
):
coros
=
[
worker
.
execute_method_async
(
"start_worker_execution_loop"
)
for
worker
in
self
.
non_driver_workers
]
return
await
asyncio
.
gather
(
*
coros
)
vllm/executor/multiproc_worker_utils.py
deleted
100644 → 0
View file @
12dbd834
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
os
import
threading
import
uuid
from
dataclasses
import
dataclass
from
multiprocessing
import
Queue
from
multiprocessing.connection
import
wait
from
multiprocessing.process
import
BaseProcess
from
typing
import
Any
,
Callable
,
Dict
,
Generic
,
List
,
Optional
,
TypeVar
,
Union
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
_maybe_force_spawn
,
decorate_logs
,
get_mp_context
,
run_method
)
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
'T'
)
_TERMINATE
=
"TERMINATE"
# sentinel
JOIN_TIMEOUT_S
=
2
@
dataclass
class
Result
(
Generic
[
T
]):
"""Result of task dispatched to worker"""
task_id
:
uuid
.
UUID
value
:
Optional
[
T
]
=
None
exception
:
Optional
[
BaseException
]
=
None
class
ResultFuture
(
threading
.
Event
,
Generic
[
T
]):
"""Synchronous future for non-async case"""
def
__init__
(
self
):
super
().
__init__
()
self
.
result
:
Optional
[
Result
[
T
]]
=
None
def
set_result
(
self
,
result
:
Result
[
T
]):
self
.
result
=
result
self
.
set
()
def
get
(
self
)
->
T
:
self
.
wait
()
assert
self
.
result
is
not
None
if
self
.
result
.
exception
is
not
None
:
raise
self
.
result
.
exception
return
self
.
result
.
value
# type: ignore[return-value]
def
_set_future_result
(
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
result
:
Result
):
if
isinstance
(
future
,
ResultFuture
):
future
.
set_result
(
result
)
return
loop
=
future
.
get_loop
()
if
not
loop
.
is_closed
():
if
result
.
exception
is
not
None
:
loop
.
call_soon_threadsafe
(
future
.
set_exception
,
result
.
exception
)
else
:
loop
.
call_soon_threadsafe
(
future
.
set_result
,
result
.
value
)
class
ResultHandler
(
threading
.
Thread
):
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
True
)
self
.
result_queue
=
get_mp_context
().
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
def
run
(
self
):
for
result
in
iter
(
self
.
result_queue
.
get
,
_TERMINATE
):
future
=
self
.
tasks
.
pop
(
result
.
task_id
)
_set_future_result
(
future
,
result
)
# Ensure that all waiters will receive an exception
for
task_id
,
future
in
self
.
tasks
.
items
():
_set_future_result
(
future
,
Result
(
task_id
=
task_id
,
exception
=
ChildProcessError
(
"worker died"
)))
def
close
(
self
):
self
.
result_queue
.
put
(
_TERMINATE
)
class
WorkerMonitor
(
threading
.
Thread
):
"""Monitor worker status (in background thread)"""
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
True
)
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
_close
=
False
def
run
(
self
)
->
None
:
# Blocks until any worker exits
dead_sentinels
=
wait
([
w
.
process
.
sentinel
for
w
in
self
.
workers
])
if
not
self
.
_close
:
self
.
_close
=
True
# Kill / cleanup all workers
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
# Cleanup any remaining workers
if
logger
:
logger
.
info
(
"Killing local vLLM worker processes"
)
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
self
.
result_handler
.
close
()
for
worker
in
self
.
workers
:
worker
.
process
.
join
(
JOIN_TIMEOUT_S
)
def
close
(
self
):
if
self
.
_close
:
return
self
.
_close
=
True
logger
.
info
(
"Terminating local vLLM worker processes"
)
for
worker
in
self
.
workers
:
worker
.
terminate_worker
()
# Must be done after worker task queues are all closed
self
.
result_handler
.
close
()
class
ProcessWorkerWrapper
:
"""Local process wrapper for vllm.worker.Worker,
for handling single-node multi-GPU tensor parallel."""
def
__init__
(
self
,
result_handler
:
ResultHandler
,
worker_factory
:
Callable
[[
VllmConfig
,
int
],
Any
],
vllm_config
:
VllmConfig
,
rank
:
int
)
->
None
:
self
.
mp
=
get_mp_context
()
self
.
_task_queue
=
self
.
mp
.
Queue
()
self
.
result_queue
=
result_handler
.
result_queue
self
.
tasks
=
result_handler
.
tasks
self
.
process
:
BaseProcess
=
self
.
mp
.
Process
(
# type: ignore[attr-defined]
target
=
_run_worker_process
,
name
=
"VllmWorkerProcess"
,
kwargs
=
dict
(
worker_factory
=
worker_factory
,
task_queue
=
self
.
_task_queue
,
result_queue
=
self
.
result_queue
,
vllm_config
=
vllm_config
,
rank
=
rank
,
),
daemon
=
True
)
self
.
process
.
start
()
def
_enqueue_task
(
self
,
future
:
Union
[
ResultFuture
,
asyncio
.
Future
],
method
:
Union
[
str
,
bytes
],
args
,
kwargs
):
task_id
=
uuid
.
uuid4
()
self
.
tasks
[
task_id
]
=
future
try
:
self
.
_task_queue
.
put
((
task_id
,
method
,
args
,
kwargs
))
except
SystemExit
:
raise
except
BaseException
as
e
:
del
self
.
tasks
[
task_id
]
raise
ChildProcessError
(
"worker died"
)
from
e
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
future
:
ResultFuture
=
ResultFuture
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
future
async
def
execute_method_async
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
_enqueue_task
(
future
,
method
,
args
,
kwargs
)
return
await
future
def
terminate_worker
(
self
):
try
:
self
.
_task_queue
.
put
(
_TERMINATE
)
except
ValueError
:
self
.
process
.
kill
()
self
.
_task_queue
.
close
()
def
kill_worker
(
self
):
self
.
_task_queue
.
close
()
self
.
process
.
kill
()
def
_run_worker_process
(
worker_factory
:
Callable
[[
VllmConfig
,
int
],
Any
],
task_queue
:
Queue
,
result_queue
:
Queue
,
vllm_config
:
VllmConfig
,
rank
:
int
,
)
->
None
:
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name
=
get_mp_context
().
current_process
().
name
decorate_logs
(
process_name
)
# Initialize worker
worker
=
worker_factory
(
vllm_config
,
rank
)
del
worker_factory
# Accept tasks from the engine in task_queue
# and return task output in result_queue
logger
.
info
(
"Worker ready; awaiting tasks"
)
try
:
for
items
in
iter
(
task_queue
.
get
,
_TERMINATE
):
output
=
None
exception
=
None
task_id
,
method
,
args
,
kwargs
=
items
try
:
output
=
run_method
(
worker
,
method
,
args
,
kwargs
)
except
SystemExit
:
raise
except
KeyboardInterrupt
:
break
except
BaseException
as
e
:
logger
.
exception
(
"Exception in worker %s while processing method %s."
,
process_name
,
method
)
exception
=
e
result_queue
.
put
(
Result
(
task_id
=
task_id
,
value
=
output
,
exception
=
exception
))
except
KeyboardInterrupt
:
pass
except
Exception
:
logger
.
exception
(
"Worker failed"
)
# Flush TunableOp results when TunableOp is enabled and
# online (in situ) tuning is enabled.
# Offline tuning API (record_untuned_is_enabled()) only
# available in PyTorch 2.6 or later.
if
torch
.
cuda
.
is_available
():
import
torch.cuda.tunable
as
tunable
if
(
tunable
.
is_enabled
()
and
tunable
.
tuning_is_enabled
()
and
not
tunable
.
record_untuned_is_enabled
()):
tunable
.
write_file
()
logger
.
info
(
"Worker exiting"
)
def
set_multiprocessing_worker_envs
(
parallel_config
):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
_maybe_force_spawn
()
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads
=
1
if
"OMP_NUM_THREADS"
not
in
os
.
environ
and
(
current_parallelism
:
=
torch
.
get_num_threads
())
>
default_omp_num_threads
:
logger
.
warning
(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed."
,
current_parallelism
,
default_omp_num_threads
)
os
.
environ
[
"OMP_NUM_THREADS"
]
=
str
(
default_omp_num_threads
)
torch
.
set_num_threads
(
default_omp_num_threads
)
vllm/v1/executor/multiproc_executor.py
View file @
7ed82d19
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
multiprocessing
import
multiprocessing
import
os
import
pickle
import
pickle
import
queue
import
queue
import
signal
import
signal
...
@@ -19,6 +20,7 @@ from threading import Thread
...
@@ -19,6 +20,7 @@ from threading import Thread
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
cast
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
cast
import
cloudpickle
import
cloudpickle
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -28,14 +30,12 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle,
...
@@ -28,14 +30,12 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue
)
MessageQueue
)
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_ep_group
,
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_ep_group
,
get_pp_group
,
get_tp_group
)
get_pp_group
,
get_tp_group
)
from
vllm.executor.multiproc_worker_utils
import
(
set_multiprocessing_worker_envs
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
worker_receiver_cache_from_config
from
vllm.multimodal.cache
import
worker_receiver_cache_from_config
from
vllm.utils
import
(
decorate_logs
,
get_distributed_init_method
,
from
vllm.utils
import
(
_maybe_force_spawn
,
decorate_logs
,
get_
loopback_ip
,
get_mp_context
,
get_open_port
,
get_
distributed_init_method
,
get_loopback_ip
,
set_process_title
)
get_mp_context
,
get_open_port
,
set_process_title
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.executor.abstract
import
Executor
,
FailureCallback
from
vllm.v1.executor.abstract
import
Executor
,
FailureCallback
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
...
@@ -67,8 +67,8 @@ class MultiprocExecutor(Executor):
...
@@ -67,8 +67,8 @@ class MultiprocExecutor(Executor):
f
"tensor_parallel_size (
{
tensor_parallel_size
}
) x pipeline"
f
"tensor_parallel_size (
{
tensor_parallel_size
}
) x pipeline"
f
"_parallel_size (
{
pp_parallel_size
}
). "
)
f
"_parallel_size (
{
pp_parallel_size
}
). "
)
# Set multiprocessing envs
that are common to V0 and V1
# Set multiprocessing envs
set_multiprocessing_worker_envs
(
self
.
parallel_config
)
set_multiprocessing_worker_envs
()
# Multiprocessing-based executor does not support multi-node setting.
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# Since it only works for single node, we can use the loopback address
...
@@ -698,3 +698,29 @@ class WorkerProc:
...
@@ -698,3 +698,29 @@ class WorkerProc:
process_name
+=
f
"_EP
{
ep_rank
}
"
process_name
+=
f
"_EP
{
ep_rank
}
"
set_process_title
(
name
=
process_name
)
set_process_title
(
name
=
process_name
)
decorate_logs
(
process_name
)
decorate_logs
(
process_name
)
def
set_multiprocessing_worker_envs
():
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
_maybe_force_spawn
()
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads
=
1
if
"OMP_NUM_THREADS"
not
in
os
.
environ
and
(
current_parallelism
:
=
torch
.
get_num_threads
())
>
default_omp_num_threads
:
logger
.
warning
(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed."
,
current_parallelism
,
default_omp_num_threads
)
os
.
environ
[
"OMP_NUM_THREADS"
]
=
str
(
default_omp_num_threads
)
torch
.
set_num_threads
(
default_omp_num_threads
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment