primitives.py 22.3 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
31
from typing import List
Boris Bonev's avatar
Boris Bonev committed
32
33
34

import torch
import torch.distributed as dist
35
from torch.amp import custom_fwd, custom_bwd
Boris Bonev's avatar
Boris Bonev committed
36

37
from .utils import polar_group, azimuth_group, polar_group_size
38
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
Boris Bonev's avatar
Boris Bonev committed
39

40
41
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """
    Compute a balanced split of ``size`` elements into ``num_chunks`` chunks.

    Parameters
    ----------
    size: int
        The size of the tensor dimension to split
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    List[int]
        Per-chunk sizes; the list has ``num_chunks`` entries that sum to ``size``,
        with at most two distinct chunk sizes and no empty chunks (for size >= num_chunks)
    """

    # treat trivial case first
    if num_chunks == 1:
        return [size]

    # first, check if we can split using div-up to balance the load:
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks - 1)

    # generate sections list: equal-sized leading chunks, remainder in the last
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections

    
Boris Bonev's avatar
Boris Bonev committed
74
def split_tensor_along_dim(tensor, dim, num_chunks):
    """
    Chunk ``tensor`` into ``num_chunks`` pieces along dimension ``dim``.

    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to split
    dim: int
        The dimension to split along
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    tensor_list: List[torch.Tensor]
        The split tensors (views into the input, as produced by ``torch.split``)
    """

    # validate that the dimension exists and can host num_chunks non-empty slices
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
    assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
                                              {num_chunks} chunks. Empty slices are currently not supported."

    # derive balanced section sizes and let torch.split perform the slicing
    section_sizes = compute_split_shapes(tensor.shape[dim], num_chunks)
    return torch.split(tensor, section_sizes, dim=dim)

103
104

def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """
    Distributed transpose: scatter shards of ``tensor`` along ``dim0`` and
    receive shards sized along ``dim1`` from all ranks of ``group`` via all-to-all.

    Parameters
    ----------
    tensor: torch.Tensor
        The local tensor to transpose
    dim0: int
        The dimension that is split locally and scattered across ranks
    dim1: int
        The dimension along which the received chunks are sized
    dim1_split_sizes: List[int]
        Per-rank sizes along ``dim1`` of the chunks this rank will receive
    group: ProcessGroup, optional
        The process group for the all-to-all (None = default group)
    async_op: bool
        If True, the all-to-all is issued asynchronously and the returned
        handle must be waited on before using ``x_recv``

    Returns
    -------
    x_recv: List[torch.Tensor]
        The received chunks; concatenating them along ``dim1`` completes the transpose
    dim0_split_sizes: List[int]
        The sizes along ``dim0`` of the chunks that were sent (needed for the inverse transpose)
    req:
        The handle returned by ``dist.all_to_all`` (a work object when async_op=True)
    """
    
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # split and local transposition: one contiguous send chunk per rank
    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
    x_send = [y.contiguous() for y in tsplit]
    # pre-allocate receive buffers: same shape as the local send chunk,
    # except dim1 which varies per source rank according to dim1_split_sizes
    x_send_shapes = [x.shape for x in x_send]
    x_recv = []
    x_shape = list(x_send_shapes[comm_rank])
    for dim1_len in dim1_split_sizes:
        x_shape[dim1] = dim1_len
        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
        
    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # get dim0 split sizes, so the caller can undo this transpose later
    dim0_split_sizes = [x[dim0] for x in x_send_shapes]
    
    return x_recv, dim0_split_sizes, req
Boris Bonev's avatar
Boris Bonev committed
146
147


Boris Bonev's avatar
Boris Bonev committed
148
class distributed_transpose_azimuth(torch.autograd.Function):
    r"""
    Autograd-enabled distributed transpose over the azimuth process group.

    The forward pass performs an all-to-all based transposition of the input
    between ``dims[0]`` and ``dims[1]``; the backward pass performs the
    inverse transposition of the incoming gradient using the dim0 split
    sizes recorded during forward.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        r"""
        Forward pass for distributed azimuthal transpose.

        Parameters
        ----------
        x: torch.Tensor
            The tensor to transpose
        dims: List[int]
            The two dimensions (dim0, dim1) to transpose between
        dim1_split_sizes: List[int]
            Per-rank split sizes along dims[1], used to size receive buffers

        Returns
        -------
        x: torch.Tensor
            The transposed tensor, re-assembled along dims[1]
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        # stash what backward needs to perform the inverse transpose
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes
        
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed azimuthal transpose.

        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output

        Returns
        -------
        gi: torch.Tensor
            The gradient of the input (plus None for the non-tensor forward args)
        """
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        # inverse transpose: swap the roles of dims[0] and dims[1]
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0])
        
        return gi, None, None
Boris Bonev's avatar
Boris Bonev committed
226
227

    
Boris Bonev's avatar
Boris Bonev committed
228
class distributed_transpose_polar(torch.autograd.Function):
    r"""
    Autograd-enabled distributed transpose over the polar process group.

    The forward pass performs an all-to-all based transposition of the input
    between ``dim[0]`` and ``dim[1]``; the backward pass performs the inverse
    transposition of the incoming gradient using the dim0 split sizes
    recorded during forward.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        r"""
        Forward pass for distributed polar transpose.

        Parameters
        ----------
        x: torch.Tensor
            The tensor to transpose
        dim: List[int]
            The two dimensions (dim0, dim1) to transpose between
        dim1_split_sizes: List[int]
            Per-rank split sizes along dim[1], used to size receive buffers

        Returns
        -------
        x: torch.Tensor
            The transposed tensor, re-assembled along dim[1]
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        # stash what backward needs to perform the inverse transpose
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed polar transpose.

        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output

        Returns
        -------
        gi: torch.Tensor
            The gradient of the input (plus None for the non-tensor forward args)
        """
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        # inverse transpose: swap the roles of dim[0] and dim[1]
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None, None
Boris Bonev's avatar
Boris Bonev committed
302

303
304
305
306
307
308
309
310
311
312
313
314
315
    
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group."""

    # nothing to reduce for a single-rank group
    if dist.get_world_size(group=group) == 1:
        return input_

    if not use_fp32:
        # reduce in the tensor's native precision
        input_ = input_.contiguous()
        dist.all_reduce(input_, group=group)
        return input_

    # accumulate in fp32 for numerical stability, then cast back
    orig_dtype = input_.dtype
    buf = input_.float().contiguous()
    dist.all_reduce(buf, group=group)
    return buf.to(orig_dtype)
Thorsten Kurth's avatar
Thorsten Kurth committed
324
    
325
326
327
328
329
330
331
332
333
334
335
336
337

def _split(input_, dim_, group=None):
    """Split the tensor along dimension ``dim_`` and keep the slice owned by this rank."""
    world_size = dist.get_world_size(group=group)

    # single-rank group: the whole tensor is this rank's slice
    if world_size == 1:
        return input_

    # chunk the tensor and pick the local shard
    # (note: torch.split does not create contiguous tensors by default)
    chunks = split_tensor_along_dim(input_, dim_, world_size)
    return chunks[dist.get_rank(group=group)]


def _gather(input_, dim_, shapes_, group=None):
    """
    Gather unevenly split tensors across ranks.

    Parameters
    ----------
    input_: torch.Tensor
        The local shard to gather
    dim_: int
        The dimension along which the gathered shards are concatenated
    shapes_: List[int] or None
        Per-rank sizes along ``dim_``; if None, all ranks are assumed to
        hold shards of the same shape as ``input_``
    group: ProcessGroup, optional
        The process group to gather over

    Returns
    -------
    output: torch.Tensor
        The gathered tensor, concatenated along ``dim_``

    Raises
    ------
    ValueError
        If ``shapes_`` does not match the group size, or ``dim_`` is out of range
    """

    comm_size = dist.get_world_size(group=group)

    # validate arguments with informative messages
    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(f"Error, the number of shapes {len(shapes_)} does not match the group size {comm_size}")
    if dim_ >= input_.dim():
        raise ValueError(f"Error, cannot gather along dim {dim_} for a tensor with {input_.dim()} dimensions")

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        # uneven gather: allocate one receive buffer per source rank,
        # each sized according to that rank's extent along dim_
        input_list = []
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_)

    return output
Boris Bonev's avatar
Boris Bonev committed
374
375


Thorsten Kurth's avatar
Thorsten Kurth committed
376
377
378
379
380
381
382
383
384
385
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group and scatter it back."""

    # a single-rank group makes reduce-scatter the identity
    if dist.get_world_size(group=group) == 1:
        return input_

    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # shard the tensor along the requested dimension, one shard per rank
    shards = split_tensor_along_dim(input_, dim_, comm_size)

    # optionally accumulate in fp32 for numerical stability
    orig_dtype = input_.dtype
    if use_fp32 and (orig_dtype != torch.float32):
        shards = [shard.to(torch.float32) for shard in shards]

    # collectives require contiguous buffers
    shards = [shard.contiguous() for shard in shards]

    # each rank receives the sum-reduced version of its own shard
    output = torch.empty_like(shards[comm_rank])
    dist.reduce_scatter(output, shards, group=group)

    # cast back to the original precision
    if use_fp32:
        output = output.to(dtype=orig_dtype)

    return output


Boris Bonev's avatar
Boris Bonev committed
405
class _CopyToPolarRegion(torch.autograd.Function):
    r"""
    Identity in forward, gradient all-reduce over the polar group in backward.

    Used to mark the point where a tensor is (conceptually) replicated into
    the polar region: the forward pass passes the input through unchanged,
    while the backward pass sums the gradients from all polar ranks.
    """
    
    @staticmethod
    def symbolic(graph, input_):
        # ONNX/symbolic tracing: plain identity
        return input_
    
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass: identity.

        Parameters
        ----------
        input_: torch.Tensor
            The tensor to copy

        Returns
        -------
        input_: torch.Tensor
            The input, unchanged
        """
        return input_
    
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass: all-reduce the gradient over the polar group.

        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output

        Returns
        -------
        grad_output: torch.Tensor
            The (possibly all-reduced) gradient of the input
        """
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            # NOTE(review): this branch returns a 2-tuple while the branch above
            # returns a bare tensor, yet forward takes a single input — verify
            # autograd tolerates the extra trailing None here
            return grad_output, None
464
465
466


class _CopyToAzimuthRegion(torch.autograd.Function):
    r"""
    Identity in forward, gradient all-reduce over the azimuth group in backward.

    Used to mark the point where a tensor is (conceptually) replicated into
    the azimuth region: the forward pass passes the input through unchanged,
    while the backward pass sums the gradients from all azimuth ranks.
    """

    @staticmethod
    def symbolic(graph, input_):
        # ONNX/symbolic tracing: plain identity
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass: identity.

        Parameters
        ----------
        input_: torch.Tensor
            The tensor to copy

        Returns
        -------
        input_: torch.Tensor
            The input, unchanged
        """
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass: all-reduce the gradient over the azimuth group.

        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output

        Returns
        -------
        grad_output: torch.Tensor
            The (possibly all-reduced) gradient of the input
        """
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
            # NOTE(review): this branch returns a 2-tuple while the branch above
            # returns a bare tensor, yet forward takes a single input — verify
            # autograd tolerates the extra trailing None here
            return grad_output, None


518
class _ScatterToPolarRegion(torch.autograd.Function):
    r"""
    Scatter tensor to the polar region for distributed computation.

    Forward splits the input along ``dim_`` and keeps the local rank's shard;
    backward gathers the gradient shards back along the same dimension, using
    the split shapes recorded in forward.

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to scatter
    dim_: int
        The dimension to scatter along

    Returns
    -------
    output: torch.Tensor
        The local shard of the scattered tensor
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        # ONNX/symbolic tracing: plain split
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        # no-op when the polar group is not distributed
        if is_distributed_polar():
            ctx.dim = dim_
            # record the per-rank shard sizes so backward can gather unevenly
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _split(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # gather is the adjoint of split; None for the dim_ argument
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None

Boris Bonev's avatar
Boris Bonev committed
561
562

class _GatherFromPolarRegion(torch.autograd.Function):
    r"""
    Gather the input from the polar region and keep it on the rank.

    Forward gathers the (possibly uneven) shards along ``dim_``; backward
    splits the gradient along the same dimension and keeps the local shard.

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to gather
    dim_: int
        The dimension to gather along
    shapes_: List[int]
        The split sizes for the dimension to gather along

    Returns
    -------
    output: torch.Tensor
        The gathered tensor
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        # ONNX/symbolic tracing: plain gather
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        # no-op when the polar group is not distributed
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # split is the adjoint of gather; None for dim_ and shapes_ arguments
        if is_distributed_polar():
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None

601
602
    
class _ReduceFromPolarRegion(torch.autograd.Function):
    r"""
    All-reduce the input from the polar region.

    Forward sums the input over the polar group; backward is the identity
    (the gradient of a sum-reduce w.r.t. each replicated input is the
    incoming gradient itself).

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to reduce

    Returns
    -------
    output: torch.Tensor
        The reduced tensor
    """
    
    @staticmethod
    def symbolic(graph, input_):
        # ONNX/symbolic tracing mirrors the eager forward
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        # no-op when the polar group is not distributed
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # identity: gradient flows through unchanged
        return grad_output

637
638
    
class _ReduceFromAzimuthRegion(torch.autograd.Function):
    r"""
    All-reduce the input from the azimuth region.

    Forward sums the input over the azimuth group; backward is the identity
    (the gradient of a sum-reduce w.r.t. each replicated input is the
    incoming gradient itself).

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to reduce

    Returns
    -------
    output: torch.Tensor
        The reduced tensor
    """

    @staticmethod
    def symbolic(graph, input_):
        # ONNX/symbolic tracing mirrors the eager forward
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        # no-op when the azimuth group is not distributed
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # identity: gradient flows through unchanged
        return grad_output

Boris Bonev's avatar
Boris Bonev committed
672

Thorsten Kurth's avatar
Thorsten Kurth committed
673
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    r"""
    All-reduce the input over the polar region and scatter it back (reduce-scatter).

    Forward performs a reduce-scatter along ``dim_``; backward gathers the
    gradient shards back along the same dimension, using the split shapes
    recorded in forward.

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to reduce
    dim_: int
        The dimension to scatter along

    Returns
    -------
    output: torch.Tensor
        The local shard of the reduced tensor
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        # ONNX/symbolic tracing mirrors the eager forward
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        # no-op when the polar group is not distributed
        if is_distributed_polar():
            ctx.dim = dim_
            # record per-rank shard sizes so backward can gather unevenly
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # gather is the adjoint of reduce-scatter; None for the dim_ argument
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    r"""
    Gather the input from the polar region and register a backward
    reduce-scatter — effectively the inverse of reduce-scatter.

    Forward gathers the (possibly uneven) shards along ``dim_``; backward
    reduce-scatters the gradient back along the same dimension.

    Parameters
    ----------
    input_: torch.Tensor
        The tensor to gather
    dim_: int
        The dimension to gather along
    shapes_: List[int]
        The split sizes for the dimension to gather along

    Returns
    -------
    output: torch.Tensor
        The gathered tensor
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        # ONNX/symbolic tracing mirrors the eager forward
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        # no-op when the polar group is not distributed
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # reduce-scatter is the adjoint of gather under replication;
        # None for the dim_ and shapes_ arguments
        if is_distributed_polar():
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
Boris Bonev's avatar
Boris Bonev committed
762
763
def copy_to_polar_region(input_):
    """Identity in forward; all-reduces the gradient over the polar group in backward."""
    return _CopyToPolarRegion.apply(input_)
764
765
766
767

def copy_to_azimuth_region(input_):
    """Identity in forward; all-reduces the gradient over the azimuth group in backward."""
    return _CopyToAzimuthRegion.apply(input_)
        
768
769
770
def reduce_from_polar_region(input_):
    """All-reduces the input over the polar group in forward; identity in backward."""
    return _ReduceFromPolarRegion.apply(input_)

771
772
def reduce_from_azimuth_region(input_):
    """All-reduces the input over the azimuth group in forward; identity in backward."""
    return _ReduceFromAzimuthRegion.apply(input_)
773
774
775

def scatter_to_polar_region(input_, dim_):
    """Splits the input along ``dim_`` over the polar group (keeping the local shard); backward gathers."""
    return _ScatterToPolarRegion.apply(input_, dim_)
Boris Bonev's avatar
Boris Bonev committed
776
777
778

def gather_from_polar_region(input_, dim_, shapes_):
    """Gathers shards of sizes ``shapes_`` along ``dim_`` over the polar group; backward splits."""
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
Thorsten Kurth's avatar
Thorsten Kurth committed
779
780
781
782
783
784

def reduce_from_scatter_to_polar_region(input_, dim_):
    """Reduce-scatters the input along ``dim_`` over the polar group; backward gathers."""
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    """Gathers shards along ``dim_`` over the polar group; backward reduce-scatters the gradient."""
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)