# communication_op.py
1
from typing import Any, Dict, Optional, Union
2

3
import torch
4
import torch.distributed
5

6
from .parallel_state import get_tp_group
7
8


9
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    Thin wrapper that delegates to ``get_tp_group().all_reduce``.

    Args:
        input_: Local tensor contributed by this rank.

    Returns:
        The value returned by ``get_tp_group().all_reduce(input_)``.
        NOTE(review): whether the reduction is done in place on ``input_``
        is decided by the group implementation -- confirm before relying
        on ``input_`` being unchanged.
    """
    return get_tp_group().all_reduce(input_)
12
13


14
15
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group.

    Thin wrapper that delegates to ``get_tp_group().all_gather``.

    Args:
        input_: Local tensor contributed by this rank.
        dim: Dimension forwarded to the group's ``all_gather`` (defaults to
            -1). Presumably the dimension along which per-rank tensors are
            concatenated -- confirm against the group implementation.

    Returns:
        The value returned by ``get_tp_group().all_gather(input_, dim)``.
    """
    return get_tp_group().all_gather(input_, dim)
18
19


20
21
22
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    Thin wrapper that delegates to ``get_tp_group().gather``.

    Args:
        input_: Local tensor contributed by this rank.
        dst: Destination rank forwarded to the group's ``gather``
            (defaults to 0). NOTE(review): confirm whether this is a rank
            local to the group or a global rank.
        dim: Dimension forwarded to the group's ``gather`` (defaults to -1);
            presumably the concatenation dimension on the destination rank.

    Returns:
        The value returned by ``get_tp_group().gather(input_, dst, dim)``.
        Presumably only meaningful on the destination rank -- confirm what
        non-destination ranks receive.
    """
    return get_tp_group().gather(input_, dst, dim)
25
26


27
28
29
30
def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                                                Any]]] = None,
                          src: int = 0):
    """Broadcast a dict of tensors/objects across model parallel group.

    Delegates to ``get_tp_group().broadcast_tensor_dict``. If
    ``torch.distributed`` has not been initialized, no communication is
    attempted and ``tensor_dict`` is returned unchanged (the same object,
    possibly ``None``).

    Args:
        tensor_dict: Mapping whose values may be tensors or arbitrary
            objects. Presumably only the ``src`` rank needs to supply it
            (hence the ``None`` default) -- confirm with the group
            implementation.
        src: Source rank forwarded to the group's ``broadcast_tensor_dict``
            (defaults to 0).

    Returns:
        ``tensor_dict`` itself when distributed is not initialized;
        otherwise the value returned by
        ``get_tp_group().broadcast_tensor_dict(tensor_dict, src)``.
    """
    # With no default process group there is nothing to broadcast to, so
    # hand the input straight back instead of touching the TP group.
    if not torch.distributed.is_initialized():
        return tensor_dict
    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)