Commit 55c719cb authored by 王敏's avatar 王敏
Browse files

[feat]支持ray分布式异步调度,VLLM_ENABLE_RAY_ASYNC_SCHEDULING环境变量控制

parent 8001970c
...@@ -559,11 +559,19 @@ class VllmConfig: ...@@ -559,11 +559,19 @@ class VllmConfig:
) )
executor_backend = self.parallel_config.distributed_executor_backend executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in ( if envs.VLLM_ENABLE_RAY_ASYNC_SCHEDULING:
"mp", executor_supports_async_sched = executor_backend in (
"uni", "mp",
"external_launcher", "uni",
) "external_launcher",
"ray"
)
else:
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher"
)
if self.scheduler_config.async_scheduling: if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities. # Async scheduling explicitly enabled, hard fail any incompatibilities.
......
...@@ -310,6 +310,7 @@ if TYPE_CHECKING: ...@@ -310,6 +310,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1940,6 +1941,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1940,6 +1941,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK": "VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in
("true", "1")), ("true", "1")),
#If set to 1/True, enable async scheduling in ray distribute mode
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING":
lambda: (os.environ.get("VLLM_ENABLE_RAY_ASYNC_SCHEDULING", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from collections import defaultdict from collections import defaultdict, deque
from collections.abc import Callable from collections.abc import Callable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from functools import partial
import cloudpickle import cloudpickle
...@@ -24,6 +25,7 @@ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType ...@@ -24,6 +25,7 @@ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import ( from vllm.v1.executor.ray_utils import (
FutureWrapper, FutureWrapper,
NonBlockFutureWrapper,
RayWorkerWrapper, RayWorkerWrapper,
initialize_ray_cluster, initialize_ray_cluster,
ray, ray,
...@@ -33,6 +35,8 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -33,6 +35,8 @@ from vllm.v1.outputs import ModelRunnerOutput
if ray is not None: if ray is not None:
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.queue import Queue as RayQueue
from ray.util.queue import Empty as EmptyError
else: else:
ActorHandle = None ActorHandle = None
...@@ -84,6 +88,9 @@ class RayDistributedExecutor(Executor): ...@@ -84,6 +88,9 @@ class RayDistributedExecutor(Executor):
if current_platform.is_tpu() or current_platform.is_xpu(): if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
assert self.uses_ray assert self.uses_ray
initialize_ray_cluster(self.parallel_config) initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -96,9 +103,6 @@ class RayDistributedExecutor(Executor): ...@@ -96,9 +103,6 @@ class RayDistributedExecutor(Executor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and ( self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
self.vllm_config.ec_transfer_config is None self.vllm_config.ec_transfer_config is None
or not self.vllm_config.ec_transfer_config.is_ec_producer or not self.vllm_config.ec_transfer_config.is_ec_producer
...@@ -164,6 +168,8 @@ class RayDistributedExecutor(Executor): ...@@ -164,6 +168,8 @@ class RayDistributedExecutor(Executor):
# the TP group of workers for a PP rank. # the TP group of workers for a PP rank.
self.pp_tp_workers: list[list[RayWorkerWrapper]] = [] self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
self.output_rank = self._get_output_rank()
if self.parallel_config.ray_workers_use_nsight: if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs ray_remote_kwargs
...@@ -193,6 +199,9 @@ class RayDistributedExecutor(Executor): ...@@ -193,6 +199,9 @@ class RayDistributedExecutor(Executor):
worker_metadata: list[RayWorkerMetaData] = [] worker_metadata: list[RayWorkerMetaData] = []
driver_ip = get_ip() driver_ip = get_ip()
self.response_mqs = [None] * len(bundle_indices)
response_mqs_tmp = [None] * len(bundle_indices)
for rank, bundle_id in enumerate(bundle_indices): for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy( scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group, placement_group=placement_group,
...@@ -200,6 +209,12 @@ class RayDistributedExecutor(Executor): ...@@ -200,6 +209,12 @@ class RayDistributedExecutor(Executor):
placement_group_bundle_index=bundle_id, placement_group_bundle_index=bundle_id,
) )
# use queue to implement actor worker response output in async scheduling mode
response_mq = None
if self.scheduler_config.async_scheduling:
response_mq = RayQueue(maxsize=256)
response_mqs_tmp[rank] = response_mq
if current_platform.ray_device_key == "GPU": if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs # NV+AMD GPUs, and Intel XPUs
worker = ray.remote( worker = ray.remote(
...@@ -207,7 +222,8 @@ class RayDistributedExecutor(Executor): ...@@ -207,7 +222,8 @@ class RayDistributedExecutor(Executor):
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote(rpc_rank=rank) )(RayWorkerWrapper).remote(use_async_scheduling=self.scheduler_config.async_scheduling,
response_mq=response_mq, rpc_rank=rank)
else: else:
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
...@@ -215,7 +231,8 @@ class RayDistributedExecutor(Executor): ...@@ -215,7 +231,8 @@ class RayDistributedExecutor(Executor):
resources={current_platform.ray_device_key: num_gpus}, resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote(rpc_rank=rank) )(RayWorkerWrapper).remote(use_async_scheduling=self.scheduler_config.async_scheduling,
response_mq=response_mq, rpc_rank=rank)
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
...@@ -261,7 +278,10 @@ class RayDistributedExecutor(Executor): ...@@ -261,7 +278,10 @@ class RayDistributedExecutor(Executor):
rerank_mapping = { rerank_mapping = {
item.created_rank: item.adjusted_rank for item in sorted_worker_metadata item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
} }
self.collective_rpc("adjust_rank", args=(rerank_mapping,)) self.collective_rpc("adjust_rank", args=(rerank_mapping, -1 if self.has_connector else self.output_rank))
for created_rank, adjusted_rank in rerank_mapping.items():
self.response_mqs[adjusted_rank] = response_mqs_tmp[created_rank]
# Get the set of GPU IDs used on each node. # Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = [] worker_node_and_gpu_ids = []
...@@ -376,6 +396,10 @@ class RayDistributedExecutor(Executor): ...@@ -376,6 +396,10 @@ class RayDistributedExecutor(Executor):
assert pp_rank < len(self.pp_tp_workers) assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank]) self.pp_tp_workers[pp_rank].append(self.workers[rank])
if self.scheduler_config.async_scheduling:
self.futures_queue = deque[tuple[NonBlockFutureWrapper, Callable]]()
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
) -> None: ) -> None:
...@@ -442,24 +466,57 @@ class RayDistributedExecutor(Executor): ...@@ -442,24 +466,57 @@ class RayDistributedExecutor(Executor):
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector: if not self.scheduler_config.async_scheduling:
# Get output only from a single worker (output_rank) if not self.has_connector:
# When PP is not used, we block here until the result is available. # Get output only from a single worker (output_rank)
if not non_block: # When PP is not used, we block here until the result is available.
return refs[0].get() if not non_block:
return refs[0].get()
# When PP is used, we return a FutureWrapper immediately so that # When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch. # the scheduler can yield to the next batch.
return FutureWrapper(refs[0]) return FutureWrapper(refs[0])
# Get output from all workers when connector is present # Get output from all workers when connector is present
assert self.kv_output_aggregator is not None assert self.kv_output_aggregator is not None
if not non_block: if not non_block:
# Block and get results from all workers # Block and get results from all workers
return self.kv_output_aggregator.aggregate(ray.get(refs)) return self.kv_output_aggregator.aggregate(ray.get(refs))
# Return a future that will aggregate outputs from all workers # Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator) return FutureWrapper(refs, self.kv_output_aggregator)
else:
if self.has_connector:
aggregate: Callable[[Any], Any] = partial(
self.kv_output_aggregator.aggregate, output_rank= self.output_rank
)
else:
aggregate = lambda x: x
output_rank = self.output_rank if not self.has_connector else None
response_mqs: Sequence[RayQueue] = self.response_mqs
if not self.has_connector:
response_mqs = (response_mqs[self.output_rank],)
def get_response():
responses = []
for mq in response_mqs:
try:
status, result = mq.get(timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
except EmptyError as e:
raise TimeoutError(f"ray exec timed out.") from e
if status != RayWorkerWrapper.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause"
)
responses.append(result)
return responses[0] if output_rank is not None else responses
future = NonBlockFutureWrapper(self.futures_queue, aggregate=aggregate)
self.futures_queue.appendleft((future, get_response))
return future
def collective_rpc( # type: ignore[override] def collective_rpc( # type: ignore[override]
self, self,
...@@ -620,3 +677,19 @@ class RayDistributedExecutor(Executor): ...@@ -620,3 +677,19 @@ class RayDistributedExecutor(Executor):
# Assume that the Ray workers are healthy. # Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers # TODO: check the health of the Ray workers
return return
def _get_output_rank(self) -> int:
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
# (the first TP worker of the last PP stage).
# Example:
# Assuming TP=8, PP=4, then the world_size=32
# 0-7, PP rank 0
# 8-15, PP rank 1
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return (
self.parallel_config.world_size
- self.parallel_config.tensor_parallel_size
* self.parallel_config.prefill_context_parallel_size
)
...@@ -3,9 +3,14 @@ ...@@ -3,9 +3,14 @@
import os import os
import time import time
from collections import defaultdict import queue
from concurrent.futures import Future from collections import defaultdict, deque
from typing import TYPE_CHECKING, Union from collections.abc import Callable, Sequence
from concurrent.futures import Future, InvalidStateError
from typing import TYPE_CHECKING, Union, Any
from threading import Thread
from enum import Enum, auto
from contextlib import suppress
import vllm.platforms import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
...@@ -28,6 +33,7 @@ PG_WAIT_TIMEOUT = 1800 ...@@ -28,6 +33,7 @@ PG_WAIT_TIMEOUT = 1800
try: try:
import ray import ray
from ray.util import placement_group_table from ray.util import placement_group_table
from ray.util.queue import Queue as RayQueue
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
try: try:
...@@ -42,7 +48,7 @@ try: ...@@ -42,7 +48,7 @@ try:
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, *args, **kwargs) -> None: def __init__(self, use_async_scheduling: bool, response_mq: RayQueue, *args, **kwargs) -> None: # type: ignore
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Since the compiled DAG runs a main execution # Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device. # in a different thread that calls cuda.set_device.
...@@ -50,6 +56,22 @@ try: ...@@ -50,6 +56,22 @@ try:
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
# async scheduling
self.use_async_scheduling = use_async_scheduling
self.worker_response_mq = response_mq
if self.use_async_scheduling:
self.async_output_queue: queue.Queue = queue.Queue()
self.async_output_copy_thread = Thread(
target=self.async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy",
)
self.async_output_copy_thread.start()
class ResponseStatus(Enum):
SUCCESS = auto()
FAILURE = auto()
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -83,9 +105,11 @@ try: ...@@ -83,9 +105,11 @@ try:
def execute_model_ray( def execute_model_ray(
self, self,
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"] execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"]
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[ ) -> Union[
"ModelRunnerOutput", "ModelRunnerOutput",
"AsyncModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]: ]:
# This method is used by Ray Compiled Graph to execute the model, # This method is used by Ray Compiled Graph to execute the model,
...@@ -107,7 +131,10 @@ try: ...@@ -107,7 +131,10 @@ try:
return scheduler_output, grammar_output, output return scheduler_output, grammar_output, output
if isinstance(output, AsyncModelRunnerOutput): if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output() if not self.use_async_scheduling:
output = output.get_output()
else:
output = output.get_output_async()
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
# Case where there are no scheduled requests # Case where there are no scheduled requests
# but may still be finished requests. # but may still be finished requests.
...@@ -115,11 +142,22 @@ try: ...@@ -115,11 +142,22 @@ try:
output = scheduler_output, grammar_output, None output = scheduler_output, grammar_output, None
elif output is None: elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output) output = self.worker.model_runner.sample_tokens(grammar_output)
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be if self.use_async_scheduling:
# pickled. if self.output_rank == -1 or self.rpc_rank == self.output_rank:
if isinstance(output, AsyncModelRunnerOutput): self.handle_output(output)
output = output.get_output()
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output_async()
else:
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
else:
if self.use_async_scheduling and (self.output_rank == -1 or self.rpc_rank == self.output_rank):
self.handle_output(output)
return output return output
def override_env_vars(self, vars: dict[str, str]): def override_env_vars(self, vars: dict[str, str]):
...@@ -127,6 +165,41 @@ try: ...@@ -127,6 +165,41 @@ try:
def _is_intermediate_tensors(self, output) -> bool: def _is_intermediate_tensors(self, output) -> bool:
return isinstance(output, IntermediateTensors) return isinstance(output, IntermediateTensors)
def enqueue_output(self, output: Any):
"""Prepares output from the worker and enqueues it to the
worker_response_mq. If the output is an Exception, it is
converted to a FAILURE response.
"""
import os
import threading
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
if isinstance(output, Exception):
result = (RayWorkerWrapper.ResponseStatus.FAILURE, str(output))
else:
result = (RayWorkerWrapper.ResponseStatus.SUCCESS, output)
if (response_mq := self.worker_response_mq) is not None:
response_mq.put(result)
def handle_output(self, output: Any):
"""Handles output from the worker. If async scheduling is enabled,
it is passed to the async_output_busy_loop thread. Otherwise, it is
enqueued directly to the worker_response_mq.
"""
if self.use_async_scheduling:
self.async_output_queue.put(output)
else:
self.enqueue_output(output)
def async_output_busy_loop(self):
"""Entrypoint for the thread which handles outputs asynchronously."""
while True:
output = self.async_output_queue.get()
self.enqueue_output(output)
ray_import_err = None ray_import_err = None
...@@ -159,6 +232,34 @@ class FutureWrapper(Future): ...@@ -159,6 +232,34 @@ class FutureWrapper(Future):
return self.aggregator.aggregate(outputs, output_rank=0) return self.aggregator.aggregate(outputs, output_rank=0)
class NonBlockFutureWrapper(Future):
def __init__(
self,
futures_queue: deque[tuple["FutureWrapper", Callable]],
aggregate: Callable = lambda x: x,
):
self.futures_queue = futures_queue
self.aggregate = aggregate
super().__init__()
def result(self, timeout=None):
if timeout is not None:
raise RuntimeError("timeout not implemented")
# Drain any futures ahead of us in the queue.
while not self.done():
future, get_response = self.futures_queue.pop()
future.wait_for_response(get_response)
return super().result()
def wait_for_response(self, get_response: Callable):
try:
response = self.aggregate(get_response())
with suppress(InvalidStateError):
self.set_result(response)
except Exception as e:
with suppress(InvalidStateError):
self.set_exception(e)
def ray_is_available() -> bool: def ray_is_available() -> bool:
"""Returns True if Ray is available.""" """Returns True if Ray is available."""
......
...@@ -200,6 +200,15 @@ class AsyncModelRunnerOutput(ABC): ...@@ -200,6 +200,15 @@ class AsyncModelRunnerOutput(ABC):
""" """
pass pass
@abstractmethod
def get_output_async(self) -> ModelRunnerOutput:
"""Get the ModelRunnerOutput for this async output.
This is a non blocking call, which return a fake out.
This method should only be called once per AsyncModelRunnerOutput.
"""
pass
@dataclass @dataclass
class DraftTokenIds: class DraftTokenIds:
......
...@@ -263,6 +263,9 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -263,6 +263,9 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
output.logprobs = logprobs_lists output.logprobs = logprobs_lists
return output return output
def get_output_async(self) -> ModelRunnerOutput:
return self._model_runner_output
class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput): class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
def __init__( def __init__(
......
...@@ -198,6 +198,7 @@ class WorkerWrapperBase: ...@@ -198,6 +198,7 @@ class WorkerWrapperBase:
""" """
self.rpc_rank = rpc_rank self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.output_rank = -1
# Initialized after init_worker is called # Initialized after init_worker is called
self.worker: WorkerBase self.worker: WorkerBase
...@@ -207,7 +208,7 @@ class WorkerWrapperBase: ...@@ -207,7 +208,7 @@ class WorkerWrapperBase:
if self.worker is not None: if self.worker is not None:
self.worker.shutdown() self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None: def adjust_rank(self, rank_mapping: dict[int, int], output_rank: int=-1) -> None:
""" """
Adjust the rpc_rank based on the given mapping. Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor, It is only used during the initialization of the executor,
...@@ -216,6 +217,8 @@ class WorkerWrapperBase: ...@@ -216,6 +217,8 @@ class WorkerWrapperBase:
if self.rpc_rank in rank_mapping: if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank] self.rpc_rank = rank_mapping[self.rpc_rank]
self.output_rank = output_rank
def update_environment_variables( def update_environment_variables(
self, self,
envs_list: list[dict[str, str]], envs_list: list[dict[str, str]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment