# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.


import torch

from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim

def _reduce(input_):
    """All-reduce `input_` in place across the tensor model parallel group."""
    # With a single tensor-parallel rank the reduction is a no-op.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # In-place sum over all ranks of the tensor model parallel group.
    torch.distributed.all_reduce(
        input_, group=get_tensor_model_parallel_group())

    return input_


def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    num_partitions = get_tensor_model_parallel_world_size()
    # Nothing to split with a single tensor-parallel rank.
    if num_partitions == 1:
        return input_

    # Partition along the last dimension.
    chunks = split_tensor_along_last_dim(input_, num_partitions)

    # torch.split does not create contiguous tensors by default, so
    # materialize the slice owned by this rank.
    this_rank = get_tensor_model_parallel_rank()
    return chunks[this_rank].contiguous()


def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Nothing to split with a single tensor-parallel rank.
    if world_size == 1:
        return input_

    # Each rank keeps an equal contiguous slab of the first dimension.
    first_dim = input_.size(0)
    assert first_dim % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"

    slab = first_dim // world_size
    start = get_tensor_model_parallel_rank() * slab
    return input_[start:start + slab].contiguous()


def _gather_along_last_dim(input_):
    """Gather tensors and concatenate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Nothing to gather with a single tensor-parallel rank.
    if world_size == 1:
        return input_

    # One receive buffer per rank; this rank's slot aliases the input.
    this_rank = get_tensor_model_parallel_rank()
    gathered = [torch.empty_like(input_) for _ in range(world_size)]
    gathered[this_rank] = input_
    torch.distributed.all_gather(
        gathered, input_, group=get_tensor_model_parallel_group())

    # torch.cat already creates a contiguous tensor; the extra call is a
    # cheap no-op kept for explicitness.
    return torch.cat(gathered, dim=input_.dim() - 1).contiguous()


def _gather_along_first_dim(input_):
    """Gather tensors and concatenate along the first dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Nothing to gather with a single tensor-parallel rank.
    if world_size == 1:
        return input_

    # Output holds world_size stacked copies along dim 0.
    out_shape = list(input_.size())
    out_shape[0] *= world_size
    output = torch.empty(out_shape, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    # NOTE(review): _all_gather_base is a private API (renamed
    # all_gather_into_tensor in newer torch) — confirm the pinned version.
    torch.distributed._all_gather_base(
        output, input_.contiguous(),
        group=get_tensor_model_parallel_group())

    return output

def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Single-rank tensor parallelism: reduce-scatter is the identity.
    if world_size == 1:
        return input_

    out_shape = list(input_.size())
    assert out_shape[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    # Each rank receives a world_size-th of the summed first dimension.
    out_shape[0] //= world_size

    output = torch.empty(out_shape, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    # NOTE(review): _reduce_scatter_base is a private API (renamed
    # reduce_scatter_tensor in newer torch) — confirm the pinned version.
    torch.distributed._reduce_scatter_base(
        output, input_.contiguous(),
        group=get_tensor_model_parallel_group())
    return output


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region.

    Forward is the identity; backward all-reduces the gradient across the
    tensor model parallel group (via `_reduce`).
    """

    @staticmethod
    def symbolic(graph, input_):
        # ONNX symbolic: identity, matching forward.
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Sum gradient contributions from all tensor-parallel ranks.
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region.

    Forward all-reduces across the tensor model parallel group (via
    `_reduce`); backward is the identity.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of a sum w.r.t. each addend is the identity.
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input along the last dimension and keep only the chunk
    corresponding to this tensor-parallel rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Reassemble the full gradient from every rank's slice.
        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate
    along the last dimension."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice it produced.
        return _split_along_last_dim(grad_output)


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input along the first dimension and keep only the chunk
    corresponding to this tensor-parallel rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Reassemble the full gradient from every rank's slab.
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate.

    `tensor_parallel_output_grad` selects the backward path: reduce-scatter
    when the downstream graph runs in tensor parallel mode, plain scatter
    when the downstream computation is duplicated across ranks.
    """

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        # Remember which backward path to take.
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather operation is in tensor
        # parallel mode, output gradients need to be reduce-scattered;
        # whereas if the computation is duplicated, output gradients only
        # need to be scattered.
        # Second return value is the (non-tensor) grad for
        # tensor_parallel_output_grad.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce scatter the input from the model parallel region.

    Forward reduce-scatters along the first dimension; backward gathers
    the gradient along the first dimension.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    """Identity in forward; all-reduce the gradient in backward."""
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    """All-reduce in forward; identity gradient in backward."""
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)
259
260


261
262
def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
263
264


265
266
def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)
267
268


269
270
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
271

272

273
274
def reduce_scatter_to_sequence_parallel_region(input_):
    """Reduce-scatter along the first dim; gather the gradient in backward."""
    return _ReduceScatterToSequenceParallelRegion.apply(input_)