Unverified commit 03a49bb8, authored by Shiyan Deng and committed by GitHub
Browse files

[Feature] Add --distributed-timeout-seconds CLI option (#36047)


Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
parent 8e87cc57
...@@ -237,6 +237,12 @@ class ParallelConfig: ...@@ -237,6 +237,12 @@ class ParallelConfig:
"""num of nodes for multi-node distributed """num of nodes for multi-node distributed
inference when distributed_executor_backend is mp.""" inference when distributed_executor_backend is mp."""
distributed_timeout_seconds: int | None = None
"""Timeout in seconds for distributed operations (e.g., init_process_group).
If set, this value is passed to torch.distributed.init_process_group as the
timeout parameter. If None, PyTorch's default timeout is used (600s for NCCL).
Increase this for multi-node setups where model downloads may be slow."""
world_size: int = Field(init=False) world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create.""" """world_size is TPxPP, it affects the number of workers we create."""
......
...@@ -403,6 +403,7 @@ class EngineArgs: ...@@ -403,6 +403,7 @@ class EngineArgs:
master_port: int = ParallelConfig.master_port master_port: int = ParallelConfig.master_port
nnodes: int = ParallelConfig.nnodes nnodes: int = ParallelConfig.nnodes
node_rank: int = ParallelConfig.node_rank node_rank: int = ParallelConfig.node_rank
distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
...@@ -814,6 +815,10 @@ class EngineArgs: ...@@ -814,6 +815,10 @@ class EngineArgs:
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"]) parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"]) parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"]) parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
parallel_group.add_argument(
"--distributed-timeout-seconds",
**parallel_kwargs["distributed_timeout_seconds"],
)
parallel_group.add_argument( parallel_group.add_argument(
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
) )
...@@ -1701,6 +1706,7 @@ class EngineArgs: ...@@ -1701,6 +1706,7 @@ class EngineArgs:
master_port=self.master_port, master_port=self.master_port,
nnodes=self.nnodes, nnodes=self.nnodes,
node_rank=self.node_rank, node_rank=self.node_rank,
distributed_timeout_seconds=self.distributed_timeout_seconds,
data_parallel_master_ip=data_parallel_address, data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend, data_parallel_backend=self.data_parallel_backend,
......
...@@ -6,6 +6,7 @@ import gc ...@@ -6,6 +6,7 @@ import gc
import os import os
from collections.abc import Callable from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
from types import NoneType from types import NoneType
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
...@@ -942,8 +943,18 @@ def init_worker_distributed_environment( ...@@ -942,8 +943,18 @@ def init_worker_distributed_environment(
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://" init_method = distributed_init_method or "env://"
timeout = None
if parallel_config.distributed_timeout_seconds is not None:
timeout = timedelta(seconds=parallel_config.distributed_timeout_seconds)
init_distributed_environment( init_distributed_environment(
parallel_config.world_size, rank, init_method, local_rank, backend parallel_config.world_size,
rank,
init_method,
local_rank,
backend,
timeout,
) )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment