Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eb24dc4a
Unverified
Commit
eb24dc4a
authored
Feb 23, 2025
by
youkaichao
Committed by
GitHub
Feb 23, 2025
Browse files
[v1] torchrun compatibility (#13642)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
9bebc951
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
67 additions
and
24 deletions
+67
-24
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/distributed/test_torchrun_example.py
tests/distributed/test_torchrun_example.py
+6
-0
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+4
-2
vllm/config.py
vllm/config.py
+5
-0
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+1
-1
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+3
-1
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+4
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+7
-2
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+17
-3
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+3
-2
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-4
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+3
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+9
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
eb24dc4a
...
@@ -503,6 +503,7 @@ steps:
...
@@ -503,6 +503,7 @@ steps:
-
entrypoints/llm/test_collective_rpc.py
-
entrypoints/llm/test_collective_rpc.py
commands
:
commands
:
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
...
...
tests/distributed/test_torchrun_example.py
View file @
eb24dc4a
...
@@ -48,6 +48,12 @@ test_consistent_across_ranks(
...
@@ -48,6 +48,12 @@ test_consistent_across_ranks(
test_consistent_across_ranks
(
test_consistent_across_ranks
(
llm
.
llm_engine
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
llm
.
llm_engine
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params
=
list
(
llm
.
llm_engine
.
model_executor
.
driver_worker
.
worker
.
model_runner
.
model
.
parameters
())
test_consistent_across_ranks
(
len
(
params
))
# all ranks should have the same outputs
# all ranks should have the same outputs
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
...
...
tests/v1/engine/test_engine_core.py
View file @
eb24dc4a
...
@@ -5,6 +5,7 @@ import threading
...
@@ -5,6 +5,7 @@ import threading
import
time
import
time
import
uuid
import
uuid
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
typing
import
List
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch):
...
@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch):
class
DummyExecutor
(
UniProcExecutor
):
class
DummyExecutor
(
UniProcExecutor
):
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_from_config
(
super
().
initialize
(
kv_cache_config
)
self
,
kv_cache_configs
:
List
[
KVCacheConfig
])
->
None
:
super
().
initialize_from_config
(
kv_cache_configs
)
# This executor actually can only run 1 batch at a time
# This executor actually can only run 1 batch at a time
self
.
semaphore
=
threading
.
Semaphore
(
1
)
self
.
semaphore
=
threading
.
Semaphore
(
1
)
...
...
vllm/config.py
View file @
eb24dc4a
...
@@ -1407,6 +1407,11 @@ class ParallelConfig:
...
@@ -1407,6 +1407,11 @@ class ParallelConfig:
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
world_size_across_dp
=
self
.
world_size
*
self
.
data_parallel_size
self
.
world_size_across_dp
=
self
.
world_size
*
self
.
data_parallel_size
if
self
.
distributed_executor_backend
==
"external_launcher"
:
import
os
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
ray_only_devices
=
[
"tpu"
]
ray_only_devices
=
[
"tpu"
]
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
(
current_platform
.
device_type
in
ray_only_devices
if
(
current_platform
.
device_type
in
ray_only_devices
...
...
vllm/executor/ray_distributed_executor.py
View file @
eb24dc4a
...
@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# and the TP group executes in SPMD fashion.
# and the TP group executes in SPMD fashion.
if
self
.
use_v1
:
if
self
.
use_v1
:
outputs
=
[
outputs
=
[
worker
.
execute_model
.
worker
.
execute_model
_ray
.
bind
(
# type: ignore[attr-defined]
bind
(
# type: ignore[attr-defined]
outputs
[
i
])
for
i
,
worker
in
enumerate
(
tp_group
)
outputs
[
i
])
for
i
,
worker
in
enumerate
(
tp_group
)
]
]
...
...
vllm/executor/ray_utils.py
View file @
eb24dc4a
...
@@ -112,10 +112,12 @@ try:
...
@@ -112,10 +112,12 @@ try:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
self
.
compiled_dag_cuda_device_set
=
True
def
execute_model
(
def
execute_model
_ray
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
"ModelRunnerOutput"
:
)
->
"ModelRunnerOutput"
:
# this method is used to compile ray CG,
# and it needs a special logic of self.setup_device_if_necessary()
self
.
setup_device_if_necessary
()
self
.
setup_device_if_necessary
()
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
if
isinstance
(
scheduler_output
,
tuple
):
if
isinstance
(
scheduler_output
,
tuple
):
...
...
vllm/executor/uniproc_executor.py
View file @
eb24dc4a
...
@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
...
@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
(
"ExecutorWithExternalLauncher needs deterministic "
(
"ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"execution, so it"
"does not support delay_factor in scheduling"
)
"does not support delay_factor in scheduling"
)
assert
not
envs
.
VLLM_USE_V1
,
\
if
envs
.
VLLM_USE_V1
:
(
"V1 architecture cannot guarantee deterministic execution, "
assert
not
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
,
\
"so it is not supported in ExecutorWithExternalLauncher."
)
(
"To get deterministic execution in V1, "
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
)
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
rpc_rank
=
0
)
rpc_rank
=
0
)
# engines are launched in torchrun-compatible launchers
# engines are launched in torchrun-compatible launchers
...
...
vllm/v1/engine/core.py
View file @
eb24dc4a
...
@@ -110,7 +110,7 @@ class EngineCore:
...
@@ -110,7 +110,7 @@ class EngineCore:
num_cpu_blocks
=
0
num_cpu_blocks
=
0
# Initialize kv cache and warmup the execution
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize
(
kv_cache_configs
)
self
.
model_executor
.
initialize
_from_config
(
kv_cache_configs
)
elapsed
=
time
.
time
()
-
start
elapsed
=
time
.
time
()
-
start
logger
.
info
((
"init engine (profile, create kv cache, "
logger
.
info
((
"init engine (profile, create kv cache, "
...
...
vllm/v1/engine/llm_engine.py
View file @
eb24dc4a
...
@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union
...
@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.envs
import
VLLM_ENABLE_V1_MULTIPROCESSING
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -44,6 +44,7 @@ class LLMEngine:
...
@@ -44,6 +44,7 @@ class LLMEngine:
use_cached_outputs
:
bool
=
False
,
use_cached_outputs
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
...
@@ -83,6 +84,10 @@ class LLMEngine:
...
@@ -83,6 +84,10 @@ class LLMEngine:
log_stats
=
False
,
# FIXME: implement
log_stats
=
False
,
# FIXME: implement
)
)
if
not
multiprocess_mode
:
# for v0 compatibility
self
.
model_executor
=
self
.
engine_core
.
engine_core
.
model_executor
# type: ignore
@
classmethod
@
classmethod
def
from_engine_args
(
def
from_engine_args
(
cls
,
cls
,
...
@@ -97,7 +102,7 @@ class LLMEngine:
...
@@ -97,7 +102,7 @@ class LLMEngine:
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
if
VLLM_ENABLE_V1_MULTIPROCESSING
:
if
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
:
logger
.
debug
(
"Enabling multiprocessing for LLMEngine."
)
logger
.
debug
(
"Enabling multiprocessing for LLMEngine."
)
enable_multiprocessing
=
True
enable_multiprocessing
=
True
...
...
vllm/v1/executor/abstract.py
View file @
eb24dc4a
...
@@ -3,6 +3,9 @@
...
@@ -3,6 +3,9 @@
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
typing
import
List
,
Type
,
Union
from
typing
import
List
,
Type
,
Union
import
torch
import
torch.distributed
as
dist
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.uniproc_executor
import
(
# noqa
from
vllm.executor.uniproc_executor
import
(
# noqa
...
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
...
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
f
"
{
distributed_executor_backend
}
"
)
f
"
{
distributed_executor_backend
}
"
)
return
executor_class
return
executor_class
def
initialize
(
self
,
kv_cache_configs
:
List
[
KVCacheConfig
])
->
None
:
def
initialize_from_config
(
self
,
kv_cache_configs
:
List
[
KVCacheConfig
])
->
None
:
"""
"""
Initialize the KV caches and begin the model execution loop of the
Initialize the KV caches and begin the model execution loop of the
underlying workers.
underlying workers.
"""
"""
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
kv_cache_configs
,
))
self
.
collective_rpc
(
"initialize_from_config"
,
args
=
(
kv_cache_configs
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
int
:
# in bytes
...
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
...
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
pass
def
determine_available_memory
(
self
)
->
int
:
# in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory
=
super
().
determine_available_memory
()
from
vllm.distributed.parallel_state
import
get_world_group
cpu_group
=
get_world_group
().
cpu_group
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
memory_tensor
.
item
()
vllm/v1/executor/multiproc_executor.py
View file @
eb24dc4a
...
@@ -216,9 +216,10 @@ class WorkerProc:
...
@@ -216,9 +216,10 @@ class WorkerProc:
"local_rank"
:
local_rank
,
"local_rank"
:
local_rank
,
"rank"
:
rank
,
"rank"
:
rank
,
"distributed_init_method"
:
distributed_init_method
,
"distributed_init_method"
:
distributed_init_method
,
"is_driver_worker"
:
rank
==
0
,
}
}
wrapper
.
init_worker
(
all_kwargs
)
wrapper
.
init_worker
(
all_kwargs
)
self
.
worker
=
wrapper
.
worker
self
.
worker
=
wrapper
pid
=
os
.
getpid
()
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
f
"VllmWorker rank=
{
rank
}
"
,
pid
)
_add_prefix
(
sys
.
stdout
,
f
"VllmWorker rank=
{
rank
}
"
,
pid
)
...
@@ -239,7 +240,7 @@ class WorkerProc:
...
@@ -239,7 +240,7 @@ class WorkerProc:
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
ready_socket
.
send
(
payload
)
ready_socket
.
send
(
payload
)
wrapp
er
.
init_device
()
self
.
work
er
.
init_device
()
self
.
worker
.
load_model
()
self
.
worker
.
load_model
()
@
staticmethod
@
staticmethod
...
...
vllm/v1/worker/gpu_worker.py
View file @
eb24dc4a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -185,9 +185,8 @@ class Worker(WorkerBase):
...
@@ -185,9 +185,8 @@ class Worker(WorkerBase):
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
return
self
.
model_runner
.
get_kv_cache_spec
()
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_
cache
(
self
,
kv_cache_config
s
:
List
[
KVCacheConfig
]
)
->
None
:
def
initialize_
from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Allocate GPU KV cache with the specified kv_cache_config."""
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config
=
kv_cache_configs
[
self
.
rank
]
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
allocator
=
CuMemAllocator
.
get_instance
()
allocator
=
CuMemAllocator
.
get_instance
()
context
=
allocator
.
use_memory_pool
(
tag
=
"kv_cache"
)
context
=
allocator
.
use_memory_pool
(
tag
=
"kv_cache"
)
...
@@ -225,7 +224,7 @@ class Worker(WorkerBase):
...
@@ -225,7 +224,7 @@ class Worker(WorkerBase):
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
Optional
[
ModelRunnerOutput
]:
)
->
Optional
[
ModelRunnerOutput
]:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
return
output
if
self
.
rank
==
0
else
None
return
output
if
self
.
is_driver_worker
else
None
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
if
self
.
profiler
is
None
:
if
self
.
profiler
is
None
:
...
...
vllm/v1/worker/tpu_worker.py
View file @
eb24dc4a
...
@@ -36,6 +36,7 @@ class TPUWorker:
...
@@ -36,6 +36,7 @@ class TPUWorker:
distributed_init_method
:
str
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
):
):
self
.
is_driver_worker
=
is_driver_worker
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
...
@@ -151,7 +152,7 @@ class TPUWorker:
...
@@ -151,7 +152,7 @@ class TPUWorker:
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
Optional
[
ModelRunnerOutput
]:
)
->
Optional
[
ModelRunnerOutput
]:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
return
output
if
self
.
rank
==
0
else
None
return
output
if
self
.
is_driver_worker
else
None
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
...
@@ -170,9 +171,8 @@ class TPUWorker:
...
@@ -170,9 +171,8 @@ class TPUWorker:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
return
self
.
model_runner
.
get_kv_cache_spec
()
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_
cache
(
self
,
kv_cache_config
s
:
List
[
KVCacheConfig
]
)
->
None
:
def
initialize_
from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Allocate GPU KV cache with the specified kv_cache_config."""
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config
=
kv_cache_configs
[
self
.
rank
]
self
.
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
self
.
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
...
...
vllm/worker/worker_base.py
View file @
eb24dc4a
...
@@ -567,6 +567,10 @@ class WorkerWrapperBase:
...
@@ -567,6 +567,10 @@ class WorkerWrapperBase:
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
def
initialize_from_config
(
self
,
kv_cache_configs
:
List
[
Any
])
->
None
:
kv_cache_config
=
kv_cache_configs
[
self
.
rpc_rank
]
self
.
worker
.
initialize_from_config
(
kv_cache_config
)
# type: ignore
def
init_device
(
self
):
def
init_device
(
self
):
with
set_current_vllm_config
(
self
.
vllm_config
):
with
set_current_vllm_config
(
self
.
vllm_config
):
# To make vLLM config available during device initialization
# To make vLLM config available during device initialization
...
@@ -574,8 +578,11 @@ class WorkerWrapperBase:
...
@@ -574,8 +578,11 @@ class WorkerWrapperBase:
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
try
:
try
:
target
=
self
if
self
.
worker
is
None
else
self
.
worker
# method resolution order:
return
run_method
(
target
,
method
,
args
,
kwargs
)
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return
run_method
(
self
,
method
,
args
,
kwargs
)
except
Exception
as
e
:
except
Exception
as
e
:
# if the driver worker also execute methods,
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# exceptions in the rest worker may cause deadlock in rpc like ray
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment