from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`, and require a recent version of
# PyTorch. The following 4 lines are for backward compatibility with older
# PyTorch versions.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle
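

# Illustrative sketch (not part of the original module): when `async_op=True`, the raw
# collectives return a work handle, so independent compute can overlap with the
# communication before `handle.wait()` is called. `other_work` is a hypothetical callable.
def _example_overlap_all_gather(input_: Tensor, process_group: ProcessGroup, other_work):
    output, handle = all_gather_raw(input_, process_group, async_op=True)
    result = other_work()  # runs while the all-gather is still in flight
    handle.wait()  # block until `output` is fully gathered
    return output, result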


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
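

# Illustrative sketch (hypothetical layer, not part of the original module): the autograd-aware
# wrappers are typically used around a matmul in sequence parallelism. `all_gather` collects the
# sequence-sharded activations in forward, and its backward reduce-scatters the gradient.
class _ExampleSequenceParallelLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.process_group = process_group

    def forward(self, x: Tensor) -> Tensor:
        # x is sharded along dim 0 (the sequence dimension) across the process group
        x_gathered = all_gather(x, self.process_group)
        return self.linear(x_gathered)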


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
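    # Illustrative usage note (assumption about the calling convention): model code is expected
    # to tag tied parameters with `p._shared_params = True`, e.g.
    #     model.lm_head.weight._shared_params = True  # hypothetical attribute path
    # so that this helper broadcasts them from (group) rank 0 to keep the replicas in sync.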
    # We want to iterate over parameters with _shared_params=True in the same order on all ranks,
    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
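    # Illustrative usage note (assumption): call this after the backward pass and before the
    # optimizer step, so that gradients of parameters tagged with `_sequence_parallel=True`
    # (e.g., norm weights) are summed across the sequence parallel group.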
    # We want to iterate over parameters with _sequence_parallel=True in the same order on all ranks,
    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of