mappings.py 8.12 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

18
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
19
20
21
22
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce ``input_`` across the tensor-model-parallel group.

    ``torch.distributed.all_reduce`` (SUM by default) works in place, so the
    tensor passed in is overwritten with the reduced values and then returned.
    With a single tensor-parallel rank there is nothing to reduce and the
    input is returned untouched.
    """
    if get_tensor_model_parallel_world_size() != 1:
        torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
    return input_


35
def _split_along_last_dim(input_):
    """Return this rank's slice of ``input_`` along its last dimension.

    ``split_tensor_along_last_dim`` divides the tensor into ``world_size``
    partitions and the partition at index ``rank`` is returned as a
    contiguous tensor. With a single rank the input is returned as-is.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        # Nothing to split across.
        return input_

    chunks = split_tensor_along_last_dim(input_, world_size)
    my_rank = get_tensor_model_parallel_rank()
    # torch.split-style pieces are not contiguous by default, so make this
    # rank's piece contiguous before handing it on.
    return chunks[my_rank].contiguous()

53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def _split_along_first_dim(input_):
    """Return this rank's contiguous slice of ``input_`` along dimension 0.

    The first dimension must be divisible by the tensor-parallel world size.
    With ``n = input_.size(0) // world_size``, rank ``r`` receives rows
    ``[r * n, (r + 1) * n)``. With a single rank the input is returned
    unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    total_rows = input_.size()[0]
    assert total_rows % world_size == 0
    rows_per_rank = total_rows // world_size
    start = get_tensor_model_parallel_rank() * rows_per_rank
    end = start + rows_per_rank
    return input_[start:end].contiguous()
73

74
75

def _gather_along_last_dim(input_):
    """All-gather ``input_`` from every tensor-parallel rank and concatenate
    the pieces along the last dimension.

    Every rank must contribute a tensor of the same shape, because each
    receive buffer is allocated with ``torch.empty_like(input_)``. With a
    single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # NCCL collectives reject non-contiguous tensors (e.g. transposed or
    # sliced views). Normalising here also makes the empty_like buffers below
    # contiguous. This is a no-op when the input is already contiguous.
    input_ = input_.contiguous()

    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    # This rank's slot can alias the input directly; no separate buffer needed.
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # torch.cat normally returns a contiguous tensor already. .contiguous()
    # keeps that guarantee explicit and costs nothing when it already holds.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


97
98
99
100
101
102
103
104
105
106
107
108
def _gather_along_first_dim(input_):
    """All-gather ``input_`` from every tensor-parallel rank and concatenate
    the pieces along the first dimension.

    The result has ``world_size * input_.size(0)`` rows and is allocated on
    the current CUDA device. Rank ``r``'s contribution occupies block ``r``.
    With a single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    # Prefer the public single-tensor all-gather (PyTorch >= 1.13). Older
    # releases only provide the private ``_all_gather_base``, which newer
    # releases deprecate, so fall back to it only when the public name is
    # missing.
    all_gather_into_tensor = getattr(torch.distributed,
                                     "all_gather_into_tensor", None)
    if all_gather_into_tensor is None:
        all_gather_into_tensor = torch.distributed._all_gather_base
    all_gather_into_tensor(output, input_.contiguous(),
                           group=get_tensor_model_parallel_group())

    return output

def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter ``input_`` across the tensor-parallel group along dim 0.

    The tensors from all ranks are reduced element-wise (SUM by default). The
    result is split into ``world_size`` equal blocks along the first
    dimension, and each rank receives its own block, allocated on the current
    CUDA device. The first dimension must be divisible by the world size.
    With a single rank the input is returned unchanged.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, (
        "first dimension ({}) is not divisible by the tensor model parallel "
        "world size ({})".format(dim_size[0], world_size))
    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    # Prefer the public single-tensor reduce-scatter (PyTorch >= 1.13) and
    # fall back to the private, deprecated ``_reduce_scatter_base`` only on
    # releases that lack it.
    reduce_scatter_tensor = getattr(torch.distributed,
                                    "reduce_scatter_tensor", None)
    if reduce_scatter_tensor is None:
        reduce_scatter_tensor = torch.distributed._reduce_scatter_base
    reduce_scatter_tensor(output, input_.contiguous(),
                          group=get_tensor_model_parallel_group())
    return output


def _reduce_scatter_along_last_dim(input_):
    """Reduce-scatter ``input_`` along its last dimension.

    This is implemented as a full all-reduce (``_reduce``) followed by keeping
    this rank's slice of the last dimension (``_split_along_last_dim``). It is
    not a dedicated reduce-scatter collective. Because ``_reduce`` all-reduces
    in place, ``input_`` itself is overwritten with the reduced values when
    there is more than one rank.
    """
    return _split_along_last_dim(_reduce(input_))


139
140
141
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

142
143
144
145
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
146
147
148
149
150
151
152
153
154
155
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
Nako Sung's avatar
Nako Sung committed
156
    """All-reduce the input from the model parallel region."""
157

158
159
160
161
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)
    
162
163
164
165
166
167
168
169
170
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


171
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Keep only this rank's chunk of the input, split along the last dim.

    Forward slices the last dimension with ``_split_along_last_dim``.
    Backward all-gathers the gradient chunks back together along the last
    dimension with ``_gather_along_last_dim``.
    """

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse of the forward split: reassemble the full-width gradient.
        return _gather_along_last_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)
185
186


187
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather every rank's chunk and concatenate them along the last dim.

    Forward all-gathers with ``_gather_along_last_dim``. Backward keeps only
    this rank's slice of the gradient with ``_split_along_last_dim``.
    """

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank only needs the gradient for the chunk it contributed.
        return _split_along_last_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)
201
202


203
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Keep only this rank's chunk of the input, split along the first dim.

    Forward slices dimension 0 with ``_split_along_first_dim``. Backward
    all-gathers the gradient chunks along dimension 0 with
    ``_gather_along_first_dim``.
    """

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Reassemble the full gradient from every rank's chunk.
        return _gather_along_first_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)


219
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the sequence-parallel chunks from every tensor-parallel rank and
    concatenate them along the first dimension.

    Forward all-gathers along dim 0 (``_gather_along_first_dim``). Backward
    reduce-scatters the gradient along dim 0
    (``_reduce_scatter_along_first_dim``), so each rank receives the reduced
    gradient for its own chunk only.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce across ranks and hand back only this rank's dim-0 chunk.
        return _reduce_scatter_along_first_dim(grad_output)


235
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input along the first dim into sequence-parallel chunks.

    Forward reduce-scatters along dimension 0 with
    ``_reduce_scatter_along_first_dim``. Backward all-gathers the gradient
    along dimension 0 with ``_gather_along_first_dim``.
    """

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank needs the full gradient of the reduced input.
        return _gather_along_first_dim(grad_output)

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)
249
250
251
252
253
254


# -----------------
# Helper functions.
# -----------------

255
def copy_to_tensor_model_parallel_region(input_):
    """Identity in forward. In backward, the gradient is all-reduced across
    the tensor-model-parallel group."""
    return _CopyToModelParallelRegion.apply(input_)

Neel Kant's avatar
Neel Kant committed
258

259
def reduce_from_tensor_model_parallel_region(input_):
    """All-reduce ``input_`` across the tensor-model-parallel group. In
    backward, the gradient is passed through unchanged."""
    return _ReduceFromModelParallelRegion.apply(input_)

Neel Kant's avatar
Neel Kant committed
262

263
264
def scatter_to_tensor_model_parallel_region(input_):
    """Keep this rank's chunk of the last dimension. In backward, the
    gradient chunks are all-gathered along the last dimension."""
    return _ScatterToModelParallelRegion.apply(input_)
265
266


267
268
def gather_from_tensor_model_parallel_region(input_):
    """All-gather along the last dimension. In backward, only this rank's
    slice of the gradient is kept."""
    return _GatherFromModelParallelRegion.apply(input_)
269
270


271
272
def scatter_to_sequence_parallel_region(input_):
    """Keep this rank's chunk of dimension 0. In backward, the gradient
    chunks are all-gathered along dimension 0."""
    return _ScatterToSequenceParallelRegion.apply(input_)
273
274


275
def gather_from_sequence_parallel_region(input_):
    """All-gather along dimension 0. In backward, the gradient is
    reduce-scattered along dimension 0."""
    return _GatherFromSequenceParallelRegion.apply(input_)
277

278

279
280
def reduce_scatter_to_sequence_parallel_region(input_):
    """Reduce-scatter along dimension 0. In backward, the gradient is
    all-gathered along dimension 0."""
    return _ReduceScatterToSequenceParallelRegion.apply(input_)
Neel Kant's avatar
Neel Kant committed
281