# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

18
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
19
20
21
22
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce ``input_`` across the tensor-model-parallel group.

    The reduction uses ``torch.distributed.all_reduce`` with its default
    reduce op (SUM). The operation is in place, so the tensor passed in is
    overwritten with the reduced values and then returned. With a
    tensor-parallel world size of 1 the tensor is returned untouched.
    """
    # Communicate only when there is more than one tensor-parallel rank.
    if get_tensor_model_parallel_world_size() != 1:
        torch.distributed.all_reduce(
            input_, group=get_tensor_model_parallel_group())
    return input_


35
def _split_along_last_dim(input_):
    """Return this rank's slice of ``input_`` along its last dimension.

    ``split_tensor_along_last_dim`` divides the tensor into ``world_size``
    pieces. The piece indexed by the current tensor-model-parallel rank is
    returned as a contiguous tensor. With a single rank the input is
    returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        # Single tensor-parallel rank: nothing to split.
        return input_

    chunks = split_tensor_along_last_dim(input_, world_size)
    my_chunk = chunks[get_tensor_model_parallel_rank()]
    # Split pieces are views that need not be contiguous, so materialize a
    # contiguous tensor for downstream kernels / collectives.
    return my_chunk.contiguous()

53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def _split_along_first_dim(input_):
    """Return this rank's slice of ``input_`` along its first dimension.

    The first dimension must be divisible by the tensor-model-parallel world
    size. With ``n = input_.size(0) // world_size``, rank ``r`` keeps rows
    ``[r * n, (r + 1) * n)``. The result is a basic-slicing view of the
    input, not a copy. With a single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    total_rows = input_.size()[0]
    assert total_rows % world_size == 0
    rows_per_rank = total_rows // world_size
    start = get_tensor_model_parallel_rank() * rows_per_rank
    end = start + rows_per_rank
    return input_[start:end]
73

74
75

def _gather_along_last_dim(input_):
    """All-gather ``input_`` from every tensor-parallel rank and concatenate
    the pieces along the last dimension.

    Rank ``i``'s tensor occupies position ``i`` in the concatenation. Every
    receive buffer is allocated with ``torch.empty_like(input_)``, so all
    ranks are expected to contribute tensors of the same shape and dtype.
    With a single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    concat_dim = input_.dim() - 1
    my_rank = get_tensor_model_parallel_rank()

    # One receive buffer per rank; this rank's own slot reuses the local tensor.
    gathered = [input_ if peer == my_rank else torch.empty_like(input_)
                for peer in range(world_size)]
    torch.distributed.all_gather(gathered, input_,
                                 group=get_tensor_model_parallel_group())

    # torch.cat already returns a contiguous tensor; .contiguous() is a no-op.
    return torch.cat(gathered, dim=concat_dim).contiguous()


97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def _gather_along_first_dim(input_):
    """All-gather ``input_`` from every tensor-parallel rank and concatenate
    the pieces along the first dimension.

    Let ``n = input_.size(0)``. The output has ``world_size * n`` rows, and
    rank ``i``'s data occupies rows ``[i * n, (i + 1) * n)``. The output is
    allocated on the current CUDA device with ``input_``'s dtype and does
    not require grad. With a single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False)

    # Flat-buffer collectives (e.g. under NCCL) reject non-contiguous tensors.
    input_ = input_.contiguous()

    # ``all_gather_into_tensor`` is the public replacement (PyTorch >= 1.13)
    # for the deprecated private ``_all_gather_base``. Fall back to the old
    # name on older PyTorch builds; both take (output, input, group=...).
    all_gather_into_tensor = getattr(torch.distributed,
                                     "all_gather_into_tensor", None)
    if all_gather_into_tensor is None:
        all_gather_into_tensor = torch.distributed._all_gather_base
    all_gather_into_tensor(output, input_,
                           group=get_tensor_model_parallel_group())

    return output

def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter ``input_`` along its first dimension across the
    tensor-model-parallel group.

    The tensors from all ranks are reduced element-wise with the collective's
    default reduce op (SUM). The result is split into ``world_size`` equal
    chunks along the first dimension, and this rank receives the chunk at
    its own rank index. The first dimension must therefore be divisible by
    the world size.

    The output is allocated on the current CUDA device with ``input_``'s
    dtype and does not require grad. With a single rank the input is
    returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU. Reuse ``world_size``
    # instead of querying the parallel state a second time.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, (
        "first dimension ({}) must be divisible by the tensor model "
        "parallel world size ({})".format(dim_size[0], world_size))
    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False)

    # Flat-buffer collectives (e.g. under NCCL) reject non-contiguous tensors.
    input_ = input_.contiguous()

    # ``reduce_scatter_tensor`` is the public replacement (PyTorch >= 1.13)
    # for the deprecated private ``_reduce_scatter_base``. Fall back to the
    # old name on older PyTorch builds; both take (output, input, group=...).
    reduce_scatter_tensor = getattr(torch.distributed,
                                    "reduce_scatter_tensor", None)
    if reduce_scatter_tensor is None:
        reduce_scatter_tensor = torch.distributed._reduce_scatter_base
    reduce_scatter_tensor(output, input_,
                          group=get_tensor_model_parallel_group())

    return output


def _reduce_scatter_along_last_dim(input_):
    """All-reduce ``input_`` across the tensor-model-parallel group, then keep
    this rank's slice along the last dimension.

    Note: the all-reduce step goes through ``_reduce``, which reduces in
    place. When the world size is greater than 1, the caller's ``input_``
    tensor is therefore overwritten with the reduced values.
    """
    return _split_along_last_dim(_reduce(input_))


145
146
147
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

148
149
150
151
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
152
153
154
155
156
157
158
159
160
161
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input across the tensor-model-parallel group.

    Forward: ``_reduce`` (an in-place all-reduce; identity with one rank).
    Backward: the incoming gradient is returned unchanged.
    """

    @staticmethod
    def forward(ctx, input_):
        # Combine the values held by every tensor-parallel rank.
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient passes straight through.
        return grad_output

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _reduce(input_)


177
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input along its last dimension and keep this rank's chunk.

    Forward: ``_split_along_last_dim``.
    Backward: ``_gather_along_last_dim`` reassembles the full gradient.
    """

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse of the forward split: gather every rank's gradient slice.
        return _gather_along_last_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _split_along_last_dim(input_)
191
192


193
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the tensor-model-parallel group and concatenate
    the pieces along the last dimension.

    Forward: ``_gather_along_last_dim``.
    Backward: ``_split_along_last_dim`` keeps this rank's gradient slice.
    """

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse of the forward gather: keep only the local slice.
        return _split_along_last_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _gather_along_last_dim(input_)
207
208


209
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input along its first dimension and keep this rank's chunk.

    Forward: ``_split_along_first_dim``.
    Backward: ``_gather_along_first_dim`` reassembles the full gradient.
    """

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse of the forward split: gather every rank's gradient rows.
        return _gather_along_first_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _split_along_first_dim(input_)


225
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """All-gather the input from the tensor-model-parallel group along the
    first dimension.

    Forward: ``_gather_along_first_dim``.
    Backward: ``_reduce_scatter_along_first_dim``. Gradients from all ranks
    are summed, and each rank keeps the rows that correspond to its own
    forward input.
    """

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce-scatter rather than a plain split: every rank holds a
        # gradient for the full gathered tensor, so they must be summed.
        return _reduce_scatter_along_first_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _gather_along_first_dim(input_)


241
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input across the tensor-model-parallel group along
    the first dimension.

    Forward: ``_reduce_scatter_along_first_dim``.
    Backward: ``_gather_along_first_dim`` reassembles the full gradient.
    """

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank holds the gradient for its own rows; gather them all.
        return _gather_along_first_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward.
        return _reduce_scatter_along_first_dim(input_)
255
256
257
258
259
260


# -----------------
# Helper functions.
# -----------------

261
def copy_to_tensor_model_parallel_region(input_):
    """Identity forward; all-reduces the gradient in backward.

    Thin wrapper around ``_CopyToModelParallelRegion.apply``.
    """
    output = _CopyToModelParallelRegion.apply(input_)
    return output
264

265
def reduce_from_tensor_model_parallel_region(input_):
    """All-reduce forward; pass-through gradient in backward.

    Thin wrapper around ``_ReduceFromModelParallelRegion.apply``.
    """
    output = _ReduceFromModelParallelRegion.apply(input_)
    return output
268

269
270
def scatter_to_tensor_model_parallel_region(input_):
    """Keep this rank's last-dimension slice; gather the gradient in backward.

    Thin wrapper around ``_ScatterToModelParallelRegion.apply``.
    """
    output = _ScatterToModelParallelRegion.apply(input_)
    return output
271
272


273
274
def gather_from_tensor_model_parallel_region(input_):
    """Gather along the last dimension; split the gradient in backward.

    Thin wrapper around ``_GatherFromModelParallelRegion.apply``.
    """
    output = _GatherFromModelParallelRegion.apply(input_)
    return output
275
276


277
278
def scatter_to_sequence_parallel_region(input_):
    """Keep this rank's first-dimension slice; gather the gradient in backward.

    Thin wrapper around ``_ScatterToSequenceParallelRegion.apply``.
    """
    output = _ScatterToSequenceParallelRegion.apply(input_)
    return output
279
280


281
def gather_from_sequence_parallel_region(input_):
    """Gather along the first dimension; reduce-scatter the gradient in backward.

    Thin wrapper around ``_GatherFromSequenceParallelRegion.apply``.
    """
    output = _GatherFromSequenceParallelRegion.apply(input_)
    return output
283

284

285
286
def reduce_scatter_to_sequence_parallel_region(input_):
    """Reduce-scatter along the first dimension; gather the gradient in backward.

    Thin wrapper around ``_ReduceScatterToSequenceParallelRegion.apply``.
    """
    output = _ReduceScatterToSequenceParallelRegion.apply(input_)
    return output