Unverified Commit c0bb9eb3 authored by Shenggui Li, committed by GitHub

[improve] made timeout configurable (#3803)

parent 7036d6fc
@@ -81,3 +81,9 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
 - **Weight**: Per-128x128-block quantization for better numerical stability.
 **Usage**: turned on by default for DeepSeek V3 models.
+
+## FAQ
+
+**Question**: What should I do if model loading takes too long and an NCCL timeout occurs?
+
+**Answer**: Try adding `--dist-timeout 3600` when launching the model; this allows a 1-hour timeout.
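For example, a launch command with this flag might look like the sketch below (the `sglang.launch_server` entry point and the model path are illustrative assumptions; only `--dist-timeout 3600` comes from this change):

```python
# Hypothetical launch invocation. The module path and model are examples for
# illustration only; --dist-timeout 3600 is the part suggested by the FAQ above.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V3",  # example model, not prescribed
        "--dist-timeout", "3600",  # allow up to 1 hour for distributed init
    ],
    check=True,
)
```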
@@ -30,6 +30,7 @@ import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
@@ -960,6 +961,7 @@ def init_distributed_environment(
     distributed_init_method: str = "env://",
     local_rank: int = -1,
     backend: str = "nccl",
+    timeout: Optional[int] = None,
 ):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
@@ -974,13 +976,20 @@ def init_distributed_environment(
             "distributed_init_method must be provided when initializing "
             "distributed environment"
         )
+
+    if timeout is not None:
+        assert isinstance(timeout, int), "timeout must be a number"
+        assert timeout > 0, "timeout must be positive"
+        timeout = timedelta(seconds=timeout)
+
     # this backend is used for WORLD
     torch.distributed.init_process_group(
         backend=backend,
         init_method=distributed_init_method,
         world_size=world_size,
         rank=rank,
+        timeout=timeout,
     )
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
......
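The conversion above is the whole mechanism: an integer number of seconds is turned into a `timedelta` before being handed to `torch.distributed.init_process_group`. A minimal standalone sketch of the same idea, using a single-process `gloo` group so it runs without GPUs (sglang itself uses `nccl`), assuming a recent PyTorch where `timeout=None` falls back to the default:

```python
# Minimal sketch of the seconds -> timedelta conversion shown in the diff above.
import os
from datetime import timedelta
from typing import Optional

import torch.distributed as dist


def init_with_timeout(timeout: Optional[int] = None) -> None:
    if timeout is not None:
        assert isinstance(timeout, int), "timeout must be an integer number of seconds"
        assert timeout > 0, "timeout must be positive"
        timeout = timedelta(seconds=timeout)

    dist.init_process_group(
        backend="gloo",  # CPU-friendly demo backend; the server uses nccl
        init_method="env://",
        world_size=1,
        rank=0,
        timeout=timeout,  # None keeps PyTorch's built-in default timeout
    )


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    init_with_timeout(3600)  # 1-hour timeout, matching the FAQ suggestion
    dist.destroy_process_group()
```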
@@ -259,6 +259,7 @@ class ModelRunner:
             rank=self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
+            timeout=self.server_args.dist_timeout,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         initialize_dp_attention(
......
@@ -79,6 +79,7 @@ class ServerArgs:
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
+    dist_timeout: Optional[int] = None  # timeout for torch.distributed
     download_dir: Optional[str] = None
     base_gpu_id: int = 0
@@ -534,6 +535,12 @@ class ServerArgs:
             default=ServerArgs.watchdog_timeout,
             help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
         )
+        parser.add_argument(
+            "--dist-timeout",
+            type=int,
+            default=ServerArgs.dist_timeout,
+            help="Set timeout (in seconds) for torch.distributed initialization.",
+        )
         parser.add_argument(
             "--download-dir",
             type=str,
......
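As a quick sanity check of the new flag's behavior, a standalone argparse sketch (not the project's actual parser) shows that the option parses to an integer and defaults to `None`, which leaves torch.distributed's built-in timeout untouched:

```python
# Standalone sketch mirroring the new CLI option; not sglang's real parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dist-timeout",
    type=int,
    default=None,  # None keeps torch.distributed's default timeout
    help="Set timeout (in seconds) for torch.distributed initialization.",
)

print(parser.parse_args([]).dist_timeout)  # None
print(parser.parse_args(["--dist-timeout", "3600"]).dist_timeout)  # 3600
```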
@@ -503,7 +503,9 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
             ret_code = run_with_timeout(
                 run_one_file, args=(filename,), timeout=timeout_per_file
             )
-            assert ret_code == 0
+            assert (
+                ret_code == 0
+            ), f"expected return code 0, but {filename} returned {ret_code}"
         except TimeoutError:
             kill_process_tree(process.pid)
             time.sleep(5)
......