# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List

import torch
import torch.distributed as dist

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar

# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """Return the chunk lengths used to split ``size`` elements into ``num_chunks`` pieces.

    The leading ``num_chunks - 1`` chunks all share one length and the last
    chunk receives whatever remains. The shared length is ``ceil(size / num_chunks)``
    unless that would leave nothing for the last chunk, in which case
    ``floor(size / num_chunks)`` is used instead.

    Args:
        size: total number of elements to distribute.
        num_chunks: number of chunks to produce.

    Returns:
        A list of ``num_chunks`` lengths that sum to ``size``.
    """
    # a single chunk simply takes everything
    if num_chunks == 1:
        return [size]

    # first attempt: round up so the leading chunks carry the bulk of the load
    head = (size + num_chunks - 1) // num_chunks
    tail = size - head * (num_chunks - 1)

    if tail <= 0:
        # rounding up starved the last chunk, so round down and let it absorb the rest
        head = size // num_chunks
        tail = size - head * (num_chunks - 1)

    return [head] * (num_chunks - 1) + [tail]

    
Boris Bonev's avatar
Boris Bonev committed
60
61
def split_tensor_along_dim(tensor, dim, num_chunks):
    """Split ``tensor`` along ``dim`` into ``num_chunks`` pieces.

    The piece lengths come from ``compute_split_shapes``, so the pieces may
    differ in size when the dimension is not evenly divisible.

    Args:
        tensor: tensor to split.
        dim: dimension to split along; must be smaller than ``tensor.dim()``.
        num_chunks: number of pieces; must not exceed ``tensor.shape[dim]``.

    Returns:
        The sequence of tensors returned by ``torch.split``.

    Raises:
        AssertionError: if ``dim`` is out of range or the dimension holds
            fewer than ``num_chunks`` elements.
    """
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
    # Adjacent literals are used instead of a backslash continuation inside the
    # string, which would embed the source indentation into the message.
    assert tensor.shape[dim] >= num_chunks, (
        f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
        f"{num_chunks} chunks. Empty slices are currently not supported."
    )

    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)

    return tensor_list

71
72

def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """Exchange slices of ``tensor`` between the ranks of ``group`` with an all-to-all.

    ``tensor`` is split along ``dim0`` into one contiguous piece per rank. The
    receive buffers take the shape of the piece addressed to the calling rank,
    with the extent along ``dim1`` replaced by the entries of ``dim1_split_sizes``.

    Args:
        tensor: local tensor to redistribute.
        dim0: dimension along which the local tensor is split for sending.
        dim1: dimension whose extent differs between the receive buffers.
        dim1_split_sizes: extent along ``dim1`` of each receive buffer.
        group: process group forwarded to ``torch.distributed``.
        async_op: forwarded to ``dist.all_to_all``.

    Returns:
        A tuple ``(received, dim0_split_sizes, req)``: the list of receive
        buffers, the extents along ``dim0`` of the pieces that were sent, and
        whatever ``dist.all_to_all`` returned.
    """
    # communicator layout
    world_size = dist.get_world_size(group=group)
    my_rank = dist.get_rank(group=group)

    # one contiguous send buffer per rank
    send_bufs = [
        piece.contiguous()
        for piece in split_tensor_along_dim(tensor, num_chunks=world_size, dim=dim0)
    ]
    send_shapes = [buf.shape for buf in send_bufs]

    # receive buffers: shape of our own piece, with dim1 set per source
    recv_shape = list(send_shapes[my_rank])
    recv_bufs = []
    for dim1_len in dim1_split_sizes:
        recv_shape[dim1] = dim1_len
        recv_bufs.append(torch.empty(recv_shape, dtype=tensor.dtype, device=tensor.device))

    # the actual exchange
    req = dist.all_to_all(recv_bufs, send_bufs, group=group, async_op=async_op)

    # extents along dim0 of what was sent, needed to reverse the exchange
    dim0_split_sizes = [shape[dim0] for shape in send_shapes]

    return recv_bufs, dim0_split_sizes, req
Boris Bonev's avatar
Boris Bonev committed
94
95


Boris Bonev's avatar
Boris Bonev committed
96
class distributed_transpose_azimuth(torch.autograd.Function):
    """Differentiable all-to-all transposition over the azimuth process group.

    Forward splits the input along ``dims[0]``, exchanges the pieces via
    ``_transpose`` and concatenates what was received along ``dims[1]``.
    Backward runs the same exchange with the two dimensions swapped.
    """

    @staticmethod
    def forward(ctx, x, dims, dim1_split_sizes):
        split_dim, cat_dim = dims[0], dims[1]
        received, dim0_split_sizes, _ = _transpose(
            x, split_dim, cat_dim, dim1_split_sizes, group=azimuth_group()
        )

        # keep what backward needs in order to reverse the exchange
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes

        return torch.cat(received, dim=cat_dim)

    @staticmethod
    def backward(ctx, go):
        # roles of the two dimensions are swapped with respect to forward
        split_dim, cat_dim = ctx.dims[1], ctx.dims[0]
        grad_pieces, _, _ = _transpose(
            go, split_dim, cat_dim, ctx.dim0_split_sizes, group=azimuth_group()
        )

        # no gradients for dims and dim1_split_sizes
        return torch.cat(grad_pieces, dim=cat_dim), None, None
Boris Bonev's avatar
Boris Bonev committed
117
118

    
Boris Bonev's avatar
Boris Bonev committed
119
class distributed_transpose_polar(torch.autograd.Function):
    """Differentiable all-to-all transposition over the polar process group.

    Forward splits the input along ``dim[0]``, exchanges the pieces via
    ``_transpose`` and concatenates what was received along ``dim[1]``.
    Backward runs the same exchange with the two dimensions swapped.
    """

    @staticmethod
    def forward(ctx, x, dim, dim1_split_sizes):
        split_dim, cat_dim = dim[0], dim[1]
        received, dim0_split_sizes, _ = _transpose(
            x, split_dim, cat_dim, dim1_split_sizes, group=polar_group()
        )

        # keep what backward needs in order to reverse the exchange
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes

        return torch.cat(received, dim=cat_dim)

    @staticmethod
    def backward(ctx, go):
        # roles of the two dimensions are swapped with respect to forward
        split_dim, cat_dim = ctx.dim[1], ctx.dim[0]
        grad_pieces, _, _ = _transpose(
            go, split_dim, cat_dim, ctx.dim0_split_sizes, group=polar_group()
        )

        # no gradients for dim and dim1_split_sizes
        return torch.cat(grad_pieces, dim=cat_dim), None, None
Boris Bonev's avatar
Boris Bonev committed
138

139
140
141
142
143
144
145
146
147
148
149
150
151
    
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce ``input_`` across the ranks of ``group``.

    Args:
        input_: tensor to reduce.
        use_fp32: when true, the reduction runs on a ``float()`` version of the
            tensor and the result is cast back to the original dtype.
        group: process group forwarded to ``torch.distributed``.

    Returns:
        ``input_`` itself when the group has a single rank, otherwise the
        reduced tensor in the dtype of ``input_``.
    """
    # nothing to communicate with a single rank
    if dist.get_world_size(group=group) == 1:
        return input_

    if not use_fp32:
        # reduce directly in the tensor's own dtype
        reduced = input_.contiguous()
        dist.all_reduce(reduced, group=group)
        return reduced

    # reduce in single precision, then restore the caller's dtype
    orig_dtype = input_.dtype
    buffer = input_.float().contiguous()
    dist.all_reduce(buffer, group=group)
    return buffer.to(orig_dtype)
Thorsten Kurth's avatar
Thorsten Kurth committed
160
    
161
162
163
164
165
166
167
168
169
170
171
172
173

def _split(input_, dim_, group=None):
    """Return the slice of ``input_`` along ``dim_`` that belongs to the calling rank.

    Args:
        input_: tensor to split.
        dim_: dimension along which the tensor is split.
        group: process group forwarded to ``torch.distributed``.

    Returns:
        ``input_`` itself when the group has a single rank, otherwise the piece
        of ``split_tensor_along_dim`` indexed by this rank.
    """
    world_size = dist.get_world_size(group=group)

    # nothing to split with a single rank
    if world_size == 1:
        return input_

    pieces = split_tensor_along_dim(input_, dim_, world_size)

    # the piece is handed back exactly as torch.split produced it;
    # no .contiguous() call is applied here
    return pieces[dist.get_rank(group=group)]


def _gather(input_, dim_, shapes_, group=None):
    """Gather (possibly unevenly split) tensors from all ranks and concatenate them.

    Args:
        input_: local tensor contributed by the calling rank.
        dim_: dimension along which the gathered tensors are concatenated.
        shapes_: extent along ``dim_`` of each rank's tensor (one entry per
            rank), or ``None`` to assume every rank holds a tensor shaped like
            ``input_``.
        group: process group forwarded to ``torch.distributed``.

    Returns:
        ``input_`` itself when the group has a single rank, otherwise the
        concatenation along ``dim_`` of the tensors gathered from all ranks.

    Raises:
        ValueError: if ``shapes_`` does not contain one entry per rank, or if
            ``dim_`` is not smaller than ``input_.dim()``.
    """
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(
            f"Error, expected one shape entry per rank ({comm_size}) but got {len(shapes_)}"
        )
    if dim_ >= input_.dim():
        raise ValueError(
            f"Error, cannot gather along dim {dim_} of a tensor with {input_.dim()} dimensions"
        )

    # nothing to communicate with a single rank
    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()

    if shapes_ is not None:
        # receive buffers differ from the local tensor only along dim_
        recv_shape = list(input_.shape)
        input_list = []
        for dim_len in shapes_:
            recv_shape[dim_] = dim_len
            input_list.append(
                torch.empty(recv_shape, dtype=input_.dtype, device=input_.device)
            )
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    return torch.cat(input_list, dim=dim_)
Boris Bonev's avatar
Boris Bonev committed
215
216


Thorsten Kurth's avatar
Thorsten Kurth committed
217
218
219
220
221
222
223
224
225
226
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """Reduce ``input_`` across ``group`` and keep only this rank's slice along ``dim_``.

    Args:
        input_: tensor to reduce.
        dim_: dimension along which the tensor is split into per-rank pieces.
        use_fp32: when true, pieces that are not already ``torch.float32`` are
            converted to it for the reduction and the result is cast back to
            the dtype of ``input_``.
        group: process group forwarded to ``torch.distributed``.

    Returns:
        ``input_`` itself when the group has a single rank, otherwise the output
        of ``dist.reduce_scatter`` for the calling rank.
    """
    comm_size = dist.get_world_size(group=group)

    # nothing to communicate with a single rank
    if comm_size == 1:
        return input_

    comm_rank = dist.get_rank(group=group)
    pieces = split_tensor_along_dim(input_, dim_, comm_size)

    # optionally upcast, then make every send buffer contiguous
    orig_dtype = input_.dtype
    upcast = use_fp32 and (orig_dtype != torch.float32)
    send_list = [
        (piece.to(torch.float32) if upcast else piece).contiguous()
        for piece in pieces
    ]

    # the output buffer mirrors the piece addressed to this rank
    output = torch.empty_like(send_list[comm_rank])
    dist.reduce_scatter(output, send_list, group=group)

    # restore the caller's dtype if the reduction ran in fp32
    return output.to(dtype=orig_dtype) if use_fp32 else output


Boris Bonev's avatar
Boris Bonev committed
246
247
248
249
250
251
252
253
254
255
class _CopyToPolarRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chunk to the rank."""
    
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    def forward(ctx, input_):
        return input_
256
    
Boris Bonev's avatar
Boris Bonev committed
257
258
259
260
261
262
263
    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output, None
        
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
    
class _ScatterToPolarRegion(torch.autograd.Function):
    """Keep this rank's slice along ``dim_`` in forward; gather the gradient in backward.

    When polar distribution is not active both passes return their input
    unchanged.
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    def forward(ctx, input_, dim_):
        if not is_distributed_polar():
            return input_

        # remember how the dimension was partitioned so backward can reassemble it
        ctx.dim = dim_
        ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None

        gathered = _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group())
        # no gradient for dim_
        return gathered, None

Boris Bonev's avatar
Boris Bonev committed
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312

class _GatherFromPolarRegion(torch.autograd.Function):
    """Gather the slices held by the polar group in forward; split the gradient in backward.

    When polar distribution is not active both passes return their input
    unchanged.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    def forward(ctx, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_

        # backward splits along the same dimension
        ctx.dim = dim_
        return _gather(input_, dim_, shapes_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None, None

        local_grad = _split(grad_output, ctx.dim, group=polar_group())
        # no gradients for dim_ and shapes_
        return local_grad, None, None

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
    
class _ReduceFromPolarRegion(torch.autograd.Function):
    """All-reduce the input over the polar group in forward; pass the gradient through in backward.

    When polar distribution is not active the forward pass returns its input
    unchanged.
    """

    @staticmethod
    def symbolic(graph, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    def forward(ctx, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        # the gradient is returned as received
        return grad_output

Boris Bonev's avatar
Boris Bonev committed
335

Thorsten Kurth's avatar
Thorsten Kurth committed
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """Reduce-scatter the input over the polar group in forward; gather the gradient in backward.

    When polar distribution is not active both passes return their input
    unchanged.
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        if not is_distributed_polar():
            return input_
        return _reduce_scatter(input_, dim_, group=polar_group())

    @staticmethod
    def forward(ctx, input_, dim_):
        if not is_distributed_polar():
            return input_

        # remember how the dimension was partitioned so backward can reassemble it
        ctx.dim = dim_
        ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
        return _reduce_scatter(input_, dim_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None

        gathered = _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group())
        # no gradient for dim_
        return gathered, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    """Gather over the polar group in forward; reduce-scatter the gradient in backward.

    When polar distribution is not active both passes return their input
    unchanged.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    def forward(ctx, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_

        # backward reduce-scatters along the same dimension
        ctx.dim = dim_
        return _gather(input_, dim_, shapes_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None, None

        local_grad = _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group())
        # no gradients for dim_ and shapes_
        return local_grad, None, None
        

        
Boris Bonev's avatar
Boris Bonev committed
392
393
394
def copy_to_polar_region(input_):
    """Apply ``_CopyToPolarRegion`` to ``input_``.

    The forward pass returns the input unchanged; in the backward pass the
    gradient is all-reduced over the polar group when polar distribution is
    active.
    """
    return _CopyToPolarRegion.apply(input_)
    
395
396
397
398
399
400
401
    
def reduce_from_polar_region(input_):
    """Apply ``_ReduceFromPolarRegion`` to ``input_``.

    The forward pass all-reduces the input over the polar group when polar
    distribution is active; the backward pass returns the gradient unchanged.
    """
    return _ReduceFromPolarRegion.apply(input_)


def scatter_to_polar_region(input_, dim_):
    """Apply ``_ScatterToPolarRegion`` to ``input_`` along ``dim_``.

    When polar distribution is active the forward pass keeps the calling
    rank's slice along ``dim_`` and the backward pass gathers the gradient
    slices again; otherwise the input is returned unchanged.
    """
    return _ScatterToPolarRegion.apply(input_, dim_)
Boris Bonev's avatar
Boris Bonev committed
402
403
404
405


def gather_from_polar_region(input_, dim_, shapes_):
    """Apply ``_GatherFromPolarRegion`` to ``input_`` along ``dim_``.

    When polar distribution is active the forward pass gathers the tensors of
    the polar group along ``dim_`` (``shapes_`` gives each rank's extent along
    that dimension) and the backward pass splits the gradient again; otherwise
    the input is returned unchanged.
    """
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
Thorsten Kurth's avatar
Thorsten Kurth committed
406
407
408
409
410
411
412
413


def reduce_from_scatter_to_polar_region(input_, dim_):
    """Apply ``_ReduceFromScatterToPolarRegion`` to ``input_`` along ``dim_``.

    When polar distribution is active the forward pass reduce-scatters the
    input over the polar group along ``dim_`` and the backward pass gathers the
    gradient slices again; otherwise the input is returned unchanged.
    """
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)


def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    """Apply ``_GatherFromCopyToPolarRegion`` to ``input_`` along ``dim_``.

    When polar distribution is active the forward pass gathers the tensors of
    the polar group along ``dim_`` (``shapes_`` gives each rank's extent along
    that dimension) and the backward pass reduce-scatters the gradient;
    otherwise the input is returned unchanged.
    """
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)