# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List

import torch
import torch.distributed as dist
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth

# helper routine to compute uneven splitting in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """
    Compute the split shapes for a given size and number of chunks.
    
    Parameters
    ----------
    size: int
        The size of the tensor to split
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    List[int]
        The split shapes
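
    Examples
    --------
    Sizes that do not divide evenly are balanced with a div-up split; if that
    would leave the last chunk empty, the floor-based fallback is used instead:

    >>> compute_split_shapes(10, 4)
    [3, 3, 3, 1]
    >>> compute_split_shapes(9, 4)
    [2, 2, 2, 3]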
    """
    
    # treat trivial case first
    if num_chunks == 1:
        return [size]
    
    # first, check if we can split using div-up to balance the load: 
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks-1)

    # generate sections list
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections

    
def split_tensor_along_dim(tensor, dim, num_chunks):
    """
    Split a tensor along a given dimension into a given number of chunks.
    
    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to split
    dim: int
        The dimension to split along
    num_chunks: int
        The number of chunks to split into
        
    Returns
    -------
    tensor_list: List[torch.Tensor]  
        The split tensors
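
    Examples
    --------
    The chunk sizes along dim follow compute_split_shapes:

    >>> t = torch.zeros(5, 8)
    >>> [s.shape[0] for s in split_tensor_along_dim(t, dim=0, num_chunks=2)]
    [3, 2]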
    """
    
    assert dim < tensor.dim(), f"Error, tensor with {tensor.dim()} dimensions cannot be split along dim {dim}"
    assert tensor.shape[dim] >= num_chunks, (
        f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
        f"{num_chunks} chunks. Empty slices are currently not supported."
    )
    
    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)
    
    return tensor_list


def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """
    Transpose a tensor along two dimensions.
    
    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to transpose
    dim0: int
        The first dimension to transpose
    dim1: int
        The second dimension to transpose
    dim1_split_sizes: List[int]
        The split sizes for the second dimension

    Returns
    -------
    tensor_list: List[torch.Tensor]
        The split tensors
    """
    
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # split and local transposition
    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
    x_send = [y.contiguous() for y in tsplit]
    x_send_shapes = [x.shape for x in x_send]
    x_recv = []
    x_shape = list(x_send_shapes[comm_rank])
    for dim1_len in dim1_split_sizes:
        x_shape[dim1] = dim1_len
        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
        
    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # get dim0 split sizes
    dim0_split_sizes = [x[dim0] for x in x_send_shapes]
    
    return x_recv, dim0_split_sizes, req


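# The autograd functions below wrap _transpose: the forward pass moves the data
# from being split along the first of the two given dimensions to being split
# along the second across the corresponding process group (azimuth or polar),
# while the backward pass performs the reverse all-to-all using the dim0 split
# sizes recorded in the context.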
class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):

        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes
        
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed azimuthal transpose.
        
        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output
        
        Returns
        -------
        gi: torch.Tensor
            The gradient of the input
        """
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0])
        
        return gi, None, None

    
class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):

        # WAR for a potential contig check torch bug for channels last contig tensors 
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed polar transpose.
        
        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output
        
        Returns
        -------
        gi: torch.Tensor
            The gradient of the input
        """
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None, None

    
# we need these additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_
    
    # All-reduce.
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
        inputf_ = inputf_.contiguous()
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
        input_ = input_.contiguous()
        dist.all_reduce(input_, group=group)
        
    return input_
    

def _split(input_, dim_, group=None):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_
    
    # Split along the given dimension.
    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    
    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()
    
    return output


def _gather(input_, dim_, shapes_, group=None):
    """Gather unevenly split tensors across ranks"""
    
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError("shapes_ must contain one entry per rank in the process group")
    if dim_ >= input_.dim():
        raise ValueError(f"dim_ {dim_} is out of bounds for a tensor with {input_.dim()} dimensions")

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        input_list = []
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_)

    return output


def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group and scatter it back."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # split the input for the reduce-scatter
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    input_list = split_tensor_along_dim(input_, dim_, comm_size)

    dtype = input_.dtype
    if (use_fp32 and (dtype != torch.float32)):
        input_list = [x.to(torch.float32) for x in input_list]

    input_list = [x.contiguous() for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if use_fp32:
        output = output.to(dtype=dtype)

    return output


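# The autograd wrappers below pair the primitives above so that forward and
# backward are adjoint operations: an identity (copy) in the forward pass becomes
# an all-reduce in the backward pass, a scatter becomes a gather, a reduce-scatter
# becomes a gather, and vice versa.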
class _CopyToPolarRegion(torch.autograd.Function):
    
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        
        return input_
    
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to polar region.
        
        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output
        
        Returns
        -------
        grad_input: torch.Tensor
            The gradient with respect to the input
        """
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output


class _CopyToAzimuthRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):

        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to azimuth region.
        
        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output
        
        Returns
        -------
        grad_input: torch.Tensor
            The gradient with respect to the input
        """
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
            return grad_output


class _ScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _split(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None

    
class _ReduceFromPolarRegion(torch.autograd.Function):
    
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output

    
class _ReduceFromAzimuthRegion(torch.autograd.Function):
 
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output


class _ReduceFromScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
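# functional wrappers around the autograd primitives defined above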
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)

def copy_to_azimuth_region(input_):
    return _CopyToAzimuthRegion.apply(input_)
        
def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

def reduce_from_azimuth_region(input_):
    return _ReduceFromAzimuthRegion.apply(input_)

def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)

def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)

def reduce_from_scatter_to_polar_region(input_, dim_):
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
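

# Illustrative usage sketch (not part of the library API): a row-parallel matrix
# multiplication over the polar group could combine these primitives as follows.
# Here `x` is assumed to be replicated across the polar group, and `w_local` is a
# hypothetical per-rank shard of the weight along its first (contraction) dimension:
#
#   x_local = scatter_to_polar_region(x, -1)        # split fwd, gather bwd
#   y_partial = x_local @ w_local                   # local partial matmul
#   y = reduce_from_polar_region(y_partial)         # all-reduce fwd, identity bwd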