Unverified commit cc99baf1, authored by Jonathan Berkhahn, committed by GitHub
Browse files

[Misc] Make timeout passable in init_distributed_environment (#24522)


Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
parent dcb28a33
...@@ -29,6 +29,7 @@ import weakref ...@@ -29,6 +29,7 @@ import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
...@@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool): ...@@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment( def init_distributed_environment(world_size: int = -1,
world_size: int = -1,
rank: int = -1, rank: int = -1,
distributed_init_method: str = "env://", distributed_init_method: str = "env://",
local_rank: int = -1, local_rank: int = -1,
backend: str = "nccl", backend: str = "nccl",
): timeout: Optional[timedelta] = None):
logger.debug( logger.debug(
"world_size=%d rank=%d local_rank=%d " "world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank, "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
...@@ -1020,7 +1020,8 @@ def init_distributed_environment( ...@@ -1020,7 +1020,8 @@ def init_distributed_environment(
backend=backend, backend=backend,
init_method=distributed_init_method, init_method=distributed_init_method,
world_size=world_size, world_size=world_size,
rank=rank) rank=rank,
timeout=timeout)
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment