# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List

import torch
import torch.distributed as dist
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth

# helper routine to compute an uneven split in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    
    # treat trivial case first
    if num_chunks == 1:
        return [size]
    
    # first, check if we can split using div-up to balance the load: 
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks-1)

    # generate sections list
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections
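
# illustrative examples (not part of the API): the div-up split is used unless
# it would leave the last shard empty, in which case the routine falls back to
# floor division:
#
#   compute_split_shapes(10, 3)  # -> [4, 4, 2]
#   compute_split_shapes(5, 4)   # -> [1, 1, 1, 2] (div-up would yield [2, 2, 2, 0])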

    
def split_tensor_along_dim(tensor, dim, num_chunks):
    assert dim < tensor.dim(), f"Error, tensor with {tensor.dim()} dimensions cannot be split along dim {dim}"
    assert tensor.shape[dim] >= num_chunks, f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into {num_chunks} chunks. Empty slices are currently not supported."
    
    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)
    return tensor_list


def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # split and local transposition
    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
    x_send = [y.contiguous() for y in tsplit]
    x_send_shapes = [x.shape for x in x_send]
    x_recv = []
    x_shape = list(x_send_shapes[comm_rank])
    for dim1_len in dim1_split_sizes:
        x_shape[dim1] = dim1_len
        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
        
    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # get dim0 split sizes
    dim0_split_sizes = [x[dim0] for x in x_send_shapes]
    return x_recv, dim0_split_sizes, req
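
# illustrative shape walk-through (assuming comm_size = 2, dim0 = 0, dim1 = 1,
# and a tensor of global shape [8, 6] stored as a local [8, 3] shard on each
# rank, i.e. dim1_split_sizes = [3, 3]): each rank splits its [8, 3] shard into
# two [4, 3] pieces along dim0 and exchanges them via all_to_all; the received
# pieces are [4, 3] buffers whose dim1 extents follow dim1_split_sizes, and
# concatenating them along dim1 (done by the callers below) yields the
# re-sharded local tensor of shape [4, 6].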


class distributed_transpose_azimuth(torch.autograd.Function):
    r"""
    Distributed transpose operation for azimuthal dimension.
    This class provides the forward and backward passes for distributed
    tensor transposition along the azimuthal dimension.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        r"""
        Forward pass for distributed azimuthal transpose.
        
        Parameters:
        x: input tensor
        dims: dimensions to transpose
        dim1_split_sizes: split sizes for dimension 1
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed azimuthal transpose.
        
        Parameters:
        go: gradient of the output
        
        Returns:
        gradient of the input
        """
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0])
        return gi, None, None

    
class distributed_transpose_polar(torch.autograd.Function):
    r"""
    Distributed transpose operation for polar dimension.
    This class provides the forward and backward passes for distributed
    tensor transposition along the polar dimension.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        r"""
        Forward pass for distributed polar transpose.
        
        Parameters:
        x: input tensor
        dim: dimensions to transpose
        dim1_split_sizes: split sizes for dimension 1
        """
        # WAR for a potential contig check torch bug for channels last contig tensors 
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed polar transpose.
        
        Parameters:
        go: gradient of the output
        
        Returns:
        gradient of the input
        """
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None, None
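
# usage sketch (illustrative; assumes torch.distributed and the polar/azimuth
# process groups from .utils have been initialized): if x is sharded along
# dim 3 across the azimuth group with shard sizes dim1_split_sizes, then
#
#   y = distributed_transpose_azimuth.apply(x, (2, 3), dim1_split_sizes)
#
# returns y sharded along dim 2 and fully assembled along dim 3.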

# we need these additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_
    
    # All-reduce.
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
        inputf_ = inputf_.contiguous()
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
        input_ = input_.contiguous()
        dist.all_reduce(input_, group=group)
        
    return input_
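
# illustrative sketch (hypothetical tensor and group): with use_fp32=True, a
# reduced-precision tensor is accumulated in float32 to limit rounding error
# and cast back afterwards:
#
#   g = torch.randn(16, dtype=torch.bfloat16, device="cuda")
#   g = _reduce(g, use_fp32=True, group=polar_group())  # summed in fp32, returned as bf16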

def _split(input_, dim_, group=None):
    """Split the tensor along dimension dim_ and keep the slice corresponding to this rank."""
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_

    # Split along the requested dimension.
    input_list = split_tensor_along_dim(input_, dim_, comm_size)

    # Note: torch.split does not create contiguous tensors by default,
    # so make the local slice contiguous before returning it.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()
    return output


def _gather(input_, dim_, shapes_, group=None):
    """Gather unevenly split tensors across ranks"""
    
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(f"Error, shapes_ must contain one entry per rank: expected {comm_size}, got {len(shapes_)}.")
    if dim_ >= input_.dim():
        raise ValueError(f"Error, cannot gather along dim {dim_} for a tensor with {input_.dim()} dimensions.")

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        input_list = []
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_)

    return output
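
# illustrative sketch: for a tensor split unevenly along dim 0 across 3 ranks
# with local shapes [4, 8], [4, 8] and [2, 8] (i.e. shapes_ = [4, 4, 2], as
# produced by compute_split_shapes(10, 3)), _gather returns the assembled
# [10, 8] tensor on every rank.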


def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group and scatter it back."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # get comm params and split the input
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    input_list = split_tensor_along_dim(input_, dim_, comm_size)

    dtype = input_.dtype
    if (use_fp32 and (dtype != torch.float32)):
        input_list = [x.to(torch.float32) for x in input_list]

    input_list = [x.contiguous() for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if use_fp32:
        output = output.to(dtype=dtype)

    return output
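
# illustrative sketch: _reduce_scatter fuses _reduce followed by _split; with
# comm_size = 2 and a local input of shape [10, 8] on each rank, rank 0 ends
# up with the element-wise sum over ranks of the first [5, 8] slice and
# rank 1 with the sum of the second [5, 8] slice (for dim_ = 0).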


class _CopyToPolarRegion(torch.autograd.Function):
    r"""
    Copy tensor to polar region for distributed computation.
    This class provides the forward and backward passes for copying
    tensors to the polar region in distributed settings.
    """
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to polar region.
        
        Parameters:
        input_: input tensor
        
        Returns:
        input tensor (no-op in forward pass)
        """
        return input_
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to polar region.
        
        Parameters:
        grad_output: gradient of the output
        
        Returns:
        gradient of the input
        """
        # forward takes a single input, so backward returns a single gradient
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output


class _CopyToAzimuthRegion(torch.autograd.Function):
    r"""
    Copy tensor to azimuth region for distributed computation.
    This class provides the forward and backward passes for copying
    tensors to the azimuth region in distributed settings.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to azimuth region.
        
        Parameters:
        input_: input tensor
        
        Returns:
        input tensor (no-op in forward pass)
        """
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to azimuth region.
        
        Parameters:
        grad_output: gradient of the output
        
        Returns:
        gradient of the input
        """
        # forward takes a single input, so backward returns a single gradient
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
            return grad_output


class _ScatterToPolarRegion(torch.autograd.Function):
    r"""
    Scatter tensor to polar region for distributed computation.
    This class provides the forward and backward passes for scattering
    tensors to the polar region in distributed settings.
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _split(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None
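
# note: scatter and gather are adjoint to each other, which is why the
# backward pass above gathers the gradients using the split shapes recorded
# in the forward pass, while _GatherFromPolarRegion below does the reverse.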


class _GatherFromPolarRegion(torch.autograd.Function):
    """Gather the input and keep it on the rank."""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None

class _ReduceFromPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region."""
    
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output

class _ReduceFromAzimuthRegion(torch.autograd.Function):
    """All-reduce the input from the azimuth region."""

    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output


class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region and scatter back to polar region."""

    @staticmethod
    def symbolic(graph, input_, dim_):
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    """Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter"""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)

def copy_to_azimuth_region(input_):
    return _CopyToAzimuthRegion.apply(input_)
        
def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

def reduce_from_azimuth_region(input_):
    return _ReduceFromAzimuthRegion.apply(input_)

def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)

def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)

def reduce_from_scatter_to_polar_region(input_, dim_):
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
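
# usage sketch (illustrative only; local_op, x, w and nlat are hypothetical,
# and the polar process group is assumed to be initialized):
#
#   shapes = compute_split_shapes(nlat, polar_group_size())
#   w = copy_to_polar_region(w)                # identity fwd, all-reduce bwd
#   x = scatter_to_polar_region(x, -2)         # shard x along dim -2
#   y = local_op(x, w)                         # rank-local computation
#   y = gather_from_polar_region(y, -2, shapes)  # reassemble the full tensor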