Unverified commit 21464e05, authored by Benjamin Lefaudeux, committed by GitHub

[feat] Gossip/SlowMo (#378)



Add SlowMo Distributed Data Parallel for clusters with slow interconnects
Co-authored-by: Vinayak Tantia <tantia.vinayak1@gmail.com>
parent 8347c1a2
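The SlowMo implementation itself is collapsed in this view, so here is a minimal usage sketch of the wrapper this commit adds. The import path, the `nprocs_per_node` argument, and the `perform_slowmo()` hook follow fairscale's experimental API as I understand it; treat them as assumptions, since none of them appear in the visible diff:

```python
# Hedged sketch: wrapping a model with the SlowMo DDP added by this commit.
# The import path, nprocs_per_node, and perform_slowmo() are assumptions
# (the implementation diff is collapsed in this view).
import torch
import torch.distributed as dist
from fairscale.experimental.nn.data_parallel import SlowMoDistributedDataParallel

dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = SlowMoDistributedDataParallel(torch.nn.Linear(32, 32).cuda(), nprocs_per_node=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 32, device="cuda")).sum()
    loss.backward()                  # gradients exchanged via gossip, not all-reduce
    optimizer.step()                 # local SGD step
    model.perform_slowmo(optimizer)  # assumed: periodic slow-momentum synchronization
```

The point of the design, per the commit message, is to tolerate slow interconnects: ranks gossip with a subset of peers each step instead of all-reducing, and a slower outer momentum step keeps replicas from drifting apart.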
@@ -18,4 +18,16 @@ def gather(tensors: Iterable[Tensor],
            destination: Optional[int] = None,
            ) -> Tensor: ...
+def broadcast_coalesced(tensors: Iterable[Tensor],
+                        devices: Iterable[int],
+                        buffer_size: int = 10485760,
+                        ) -> Tuple[Tensor, ...]: ...
+def reduce_add_coalesced(inputs: Iterable[Iterable[Tensor]],
+                         destination: Optional[int] = None,
+                         buffer_size: int = 10485760,
+                         ) -> Tuple[Tensor, ...]: ...
 #END
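These stub entries mirror `torch.cuda.comm` (the file is presumably `stubs/torch/cuda/comm.pyi`). A short sketch of the two newly typed helpers, using the signatures shown above; it needs at least two CUDA devices to do anything useful:

```python
# Sketch of the two helpers the new stub entries type (torch.cuda.comm).
import torch
import torch.cuda.comm as comm

tensors = [torch.ones(4, device="cuda:0"), torch.arange(4.0, device="cuda:0")]

# One coalesced transfer that copies every tensor to every listed device;
# returns one tuple of copies per device.
per_device = comm.broadcast_coalesced(tensors, devices=[0, 1], buffer_size=10485760)

# Coalesced elementwise sum of the per-device tensor lists onto device 0.
summed = comm.reduce_add_coalesced(per_device, destination=0)
```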
@@ -16,6 +16,9 @@ class ProcessGroup:
     def size(self) -> int: ...
     def rank(self) -> int: ...
+class Work:
+    def wait(self) -> None: ...
 class ReduceOp:
     SUM: ReduceOp
     PRODUCT: ReduceOp
@@ -26,15 +29,27 @@ class ReduceOp:
     BXOR: ReduceOp
 def get_rank(group: Any = None) -> int: ...
 def get_world_size(group: Any = None) -> int: ...
 def get_backend(group: Optional[Any] = None) -> Any: ...
-def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
-def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
-def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
-def broadcast_object_list(object_list: List[Any], src: int, group:Optional[ProcessGroup] = None): ...
+def broadcast(tensor: Tensor, src: Any, group: Optional[Any] = None, async_op: Any = False): ...
+def gather(
+    tensor: Tensor,
+    gather_list: Optional[List[Tensor]],
+    dst: Any,
+    group: Optional[ProcessGroup] = None,
+    async_op: Optional[bool] = False,
+): ...
+def reduce(
+    tensor: Tensor,
+    dst: Any,
+    op: Optional[Any] = ReduceOp.SUM,
+    group: Optional[ProcessGroup] = None,
+    async_op: Optional[bool] = False,
+): ...
+def broadcast_object_list(object_list: List[Any], src: int, group: Optional[ProcessGroup] = None): ...
 def is_available() -> bool: ...
 def is_initialized() -> bool: ...
 def is_nccl_available() -> bool: ...
 def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: Optional[Sequence[int]] = None,
@@ -51,11 +66,15 @@ def _all_gather_base(input_tensor: Tensor, output_tensor: Tensor, group:Optional
 def _reduce_scatter_base(output_tensor: Tensor, input_tensor: Tensor, group:Optional[ProcessGroup] = None): ...
 def destroy_process_group() -> None: ...
 def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
 def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
-def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
-def irecv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
+def recv(
+    tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None
+) -> int: ...
+def irecv(
+    tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None
+) -> int: ...
 def _broadcast_coalesced(process_group: ProcessGroup, tensors: List[Tensor], buffer_size: int) -> None: ...
 class group(object):
     WORLD: Any
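The additions in this file (the `Work` class, the reflowed collectives, the multi-line `recv`/`irecv`) presumably type the calls the gossip code makes. A sketch exercising that surface with the standard `torch.distributed` API; run it under `torchrun` (or with `MASTER_ADDR`/`MASTER_PORT` set) so `env://` resolves:

```python
# Sketch of the surface the updated stubs cover: get_backend, an async
# collective returning a Work handle, and blocking point-to-point send/recv.
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="env://")
print(dist.get_backend())                        # e.g. "gloo"

t = torch.ones(4)
work = dist.broadcast(t, src=0, async_op=True)   # Work handle, per the new stub
work.wait()                                      # Work.wait(), typed above

dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)      # sum every rank's t onto rank 0

if dist.get_world_size() > 1:
    if dist.get_rank() == 0:
        dist.send(torch.arange(4.0), dst=1)
    elif dist.get_rank() == 1:
        buf = torch.empty(4)
        sender = dist.recv(buf, src=0)           # returns the sender's rank
```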
@@ -5,3 +5,5 @@ from typing import Any, List, Union, Optional
 from . import ProcessGroup
 def _get_global_rank(group: ProcessGroup, rank: int) -> int: ...
+def _get_default_group() -> ProcessGroup: ...
\ No newline at end of file
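The newly typed `_get_default_group` (presumably in `stubs/torch/distributed/distributed_c10d.pyi`) is a private torch helper; a plausible reading is that the gossip code falls back to the default group when the caller passes none. A hedged sketch, noting that these internals can change between torch versions:

```python
# Sketch using the private c10d helpers the stub above types; private torch
# internals, so this may break across versions.
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank

dist.init_process_group(backend="gloo", init_method="env://")
group = _get_default_group()        # the group created by init_process_group
rank0 = _get_global_rank(group, 0)  # map group-local rank 0 to its global rank
```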
@@ -19,3 +19,4 @@ tests/optim/test_adam.py
 tests/optim/test_oss.py
 tests/optim/test_oss_adascale.py
 tests/optim/test_ddp_adascale.py
+tests/experimental/nn/data_parallel/test_gossip.py
This diff is collapsed.
@@ -298,22 +298,18 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename, filename_r
     print(" >> passed the test :-)")
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
 def test_affine_weight():
-    spawn_for_all_world_sizes(run_test_initialize_affine_weight)
+    spawn_for_all_world_sizes(run_test_initialize_affine_weight, deterministic=True)
 def test_embedding():
-    spawn_for_all_world_sizes(run_test_parallel_embedding)
+    spawn_for_all_world_sizes(run_test_parallel_embedding, deterministic=True)
 def test_column_parallel():
-    spawn_for_all_world_sizes(run_test_column_parallel_linear)
+    spawn_for_all_world_sizes(run_test_column_parallel_linear, deterministic=True)
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
 def test_row_parallel():
-    spawn_for_all_world_sizes(run_test_row_parallel_linear)
+    spawn_for_all_world_sizes(run_test_row_parallel_linear, deterministic=True)
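The pattern in this hunk: the module set the cudnn determinism flags directly, and each test now passes `deterministic=True` to the spawn helper instead, so the flags get set inside each spawned worker. A hypothetical reconstruction of what that helper might look like; the real `spawn_for_all_world_sizes` lives in the fairscale test commons and its signature is not shown in this diff:

```python
# Hypothetical reconstruction of the test helper; names, world sizes, and the
# sync-file paths are assumptions, not taken from this diff.
import torch
import torch.multiprocessing as mp

def _worker(rank, test_func, world_size, filename, filename_rpc, deterministic):
    if deterministic:
        # The cudnn toggles the test module used to set at import time;
        # they must be set in the child process to take effect there.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    test_func(rank, world_size, filename, filename_rpc)

def spawn_for_all_world_sizes(test_func, world_sizes=(1, 2), deterministic=False):
    for world_size in world_sizes:
        mp.spawn(
            _worker,
            args=(test_func, world_size, "/tmp/sync", "/tmp/sync_rpc", deterministic),
            nprocs=world_size,
        )
```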