primitives.py 15.2 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
31
from typing import List
Boris Bonev's avatar
Boris Bonev committed
32
33
34

import torch
import torch.distributed as dist
35
from torch.amp import custom_fwd, custom_bwd
Boris Bonev's avatar
Boris Bonev committed
36

37
from .utils import polar_group, azimuth_group, polar_group_size
38
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
Boris Bonev's avatar
Boris Bonev committed
39

40
41
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """
    Compute balanced chunk sizes for splitting ``size`` elements into ``num_chunks`` pieces.

    The first ``num_chunks - 1`` chunks share a common size and the last chunk
    takes whatever remains. A div-up (ceiling) chunk size is tried first; if
    that would leave the last chunk empty, a floor chunk size is used instead.

    Parameters
    ----------
    size: int
        The size of the tensor dimension to split
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    List[int]
        The split shapes, one entry per chunk, summing to ``size``
    """
    # a single chunk holds everything
    if num_chunks == 1:
        return [size]

    # try the div-up size first so the leading chunks absorb the remainder
    leading = (size + num_chunks - 1) // num_chunks
    remainder = size - leading * (num_chunks - 1)
    if remainder <= 0:
        # the div-up split would leave the last shard empty: use floor instead
        leading = size // num_chunks
        remainder = size - leading * (num_chunks - 1)

    return [leading] * (num_chunks - 1) + [remainder]

    
Boris Bonev's avatar
Boris Bonev committed
74
def split_tensor_along_dim(tensor, dim, num_chunks):
    """
    Split a tensor along a given dimension into a given number of chunks.

    Chunk sizes are determined by ``compute_split_shapes``, so all chunks but
    the last share the same size along ``dim``.

    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to split
    dim: int
        The dimension to split along; negative values count from the end
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    tensor_list: Tuple[torch.Tensor, ...]
        The split tensors, as returned by ``torch.split`` (views of ``tensor``)

    Raises
    ------
    AssertionError
        If ``dim`` is out of range for ``tensor``, or if ``tensor.shape[dim]``
        is smaller than ``num_chunks`` (empty slices are not supported).
    """
    # Explicit raises (instead of ``assert`` statements) keep the AssertionError
    # contract callers see, while still validating when Python runs with -O.
    ndim = tensor.dim()
    if not -ndim <= dim < ndim:
        raise AssertionError(f"Error, tensor dimension is {ndim} which cannot be split along {dim}")
    if tensor.shape[dim] < num_chunks:
        # implicit string concatenation avoids embedding indentation in the message
        raise AssertionError(
            f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
            f"{num_chunks} chunks. Empty slices are currently not supported."
        )

    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)

    return tensor_list

103
104

def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """
    Exchange chunks of ``tensor`` between the ranks of ``group`` with ``dist.all_to_all``.

    The local tensor is chunked along ``dim0`` (one chunk per rank) and chunk
    ``i`` is sent to rank ``i``. The pieces received from each rank are
    returned separately; callers concatenate them along ``dim1``.

    Parameters
    ----------
    tensor: torch.Tensor
        Local tensor to redistribute
    dim0: int
        Dimension that is chunked locally before sending
    dim1: int
        Dimension whose size differs between the received pieces
    dim1_split_sizes: List[int]
        Size along ``dim1`` of the piece received from each rank, in rank order
    group: optional
        Process group to communicate over
    async_op: bool
        Forwarded to ``dist.all_to_all``

    Returns
    -------
    x_recv: List[torch.Tensor]
        Receive buffers, one per rank
    dim0_split_sizes: List[int]
        Sizes along ``dim0`` of the chunks that were sent, in rank order
    req:
        Whatever ``dist.all_to_all`` returned for this call
    """
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    # chunk locally along dim0 and make every send buffer contiguous
    send_chunks = [chunk.contiguous() for chunk in split_tensor_along_dim(tensor, num_chunks=world, dim=dim0)]
    send_shapes = [chunk.shape for chunk in send_chunks]

    # receive buffers mirror this rank's chunk shape, with dim1 resized per source rank
    template = list(send_shapes[rank])
    recv_chunks = []
    for recv_len in dim1_split_sizes:
        template[dim1] = recv_len
        recv_chunks.append(torch.empty(template, dtype=tensor.dtype, device=tensor.device))

    # global transposition
    handle = dist.all_to_all(recv_chunks, send_chunks, group=group, async_op=async_op)

    return recv_chunks, [shape[dim0] for shape in send_shapes], handle
Boris Bonev's avatar
Boris Bonev committed
127
128


Boris Bonev's avatar
Boris Bonev committed
129
class distributed_transpose_azimuth(torch.autograd.Function):
    """
    Autograd-aware all-to-all transpose over the azimuth process group.

    Forward: the local tensor is chunked along ``dims[0]``, chunks are exchanged
    across the azimuth group, and the received pieces are concatenated along
    ``dims[1]``. Backward: the inverse exchange is applied to the gradient.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        src_dim, dst_dim = dims[0], dims[1]
        pieces, dim0_split_sizes, _ = _transpose(x, src_dim, dst_dim, dim1_split_sizes, group=azimuth_group())
        out = torch.cat(pieces, dim=dst_dim)

        # remember how dims[0] was chunked so backward can restore that layout
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes

        return out

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        src_dim, dst_dim = ctx.dims[0], ctx.dims[1]
        # inverse exchange: chunk along dims[1], reassemble along dims[0]
        pieces, _, _ = _transpose(go, dst_dim, src_dim, ctx.dim0_split_sizes, group=azimuth_group())
        grad_in = torch.cat(pieces, dim=src_dim)

        # no gradients for the layout arguments
        return grad_in, None, None
Boris Bonev's avatar
Boris Bonev committed
153
154

    
Boris Bonev's avatar
Boris Bonev committed
155
class distributed_transpose_polar(torch.autograd.Function):
    """
    Autograd-aware all-to-all transpose over the polar process group.

    Forward: the local tensor is chunked along ``dim[0]``, chunks are exchanged
    across the polar group, and the received pieces are concatenated along
    ``dim[1]``. Backward: the inverse exchange is applied to the gradient.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        src_dim, dst_dim = dim[0], dim[1]
        received, dim0_split_sizes, _ = _transpose(x, src_dim, dst_dim, dim1_split_sizes, group=polar_group())
        out = torch.cat(received, dim=dst_dim)

        # stash the layout needed to undo the exchange in backward
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes

        return out

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        src_dim, dst_dim = ctx.dim[0], ctx.dim[1]
        # inverse exchange: chunk along dim[1], reassemble along dim[0]
        received, _, _ = _transpose(go, dst_dim, src_dim, ctx.dim0_split_sizes, group=polar_group())
        grad_in = torch.cat(received, dim=src_dim)

        # no gradients for the layout arguments
        return grad_in, None, None
Boris Bonev's avatar
Boris Bonev committed
178

179
180
181
182
183
184
185
186
187
188
189
190
    
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_
    
    # All-reduce.
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
191
        inputf_ = inputf_.contiguous()
192
193
194
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
195
        input_ = input_.contiguous()
196
197
198
        dist.all_reduce(input_, group=group)
        
    return input_
Thorsten Kurth's avatar
Thorsten Kurth committed
199
    
200
201
202
203
204
205
206
207
208
209
210
211

def _split(input_, dim_, group=None):
    """
    Return this rank's chunk of ``input_`` along ``dim_``.

    Chunk sizes follow ``split_tensor_along_dim``. The returned tensor is the
    chunk produced by ``torch.split``, i.e. a view that need not be contiguous.
    With a single rank, ``input_`` is returned unchanged.
    """
    world = dist.get_world_size(group=group)
    if world == 1:
        return input_

    chunks = split_tensor_along_dim(input_, dim_, world)
    return chunks[dist.get_rank(group=group)]


def _gather(input_, dim_, shapes_, group=None):
    
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError()
    if dim_ >= input_.dim():
        raise ValueError()

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
234
        input_list = []
235
236
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
237
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
238
239
240
241
242
243
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

244
    output = torch.cat(input_list, dim=dim_)
245
246

    return output
Boris Bonev's avatar
Boris Bonev committed
247
248


Thorsten Kurth's avatar
Thorsten Kurth committed
249
250
251
252
253
254
255
256
257
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """
    Sum-reduce ``input_`` across ``group`` and return this rank's chunk along ``dim_``.

    ``input_`` is split along ``dim_`` with ``split_tensor_along_dim``.
    ``dist.reduce_scatter`` (default op: sum) leaves each rank holding the
    reduced chunk that matches its rank.

    Parameters
    ----------
    input_: torch.Tensor
        Full local tensor to reduce
    dim_: int
        Dimension along which the reduced result is scattered
    use_fp32: bool
        If True, floating-point inputs narrower than 32 bits are reduced in
        float32 and cast back afterwards. float64, complex and integer inputs
        are always reduced in their own dtype, so no precision is lost.
    group: optional
        Process group to reduce over

    Returns
    -------
    torch.Tensor
        This rank's reduced chunk in the dtype of ``input_``, or ``input_``
        unchanged for a single rank
    """
    comm_size = dist.get_world_size(group=group)

    # Bypass the function if we are using only 1 GPU.
    if comm_size == 1:
        return input_

    comm_rank = dist.get_rank(group=group)

    # only promote low-precision floats; downcasting wider or non-real types would lose data
    dtype = input_.dtype
    upcast = use_fp32 and dtype.is_floating_point and torch.finfo(dtype).bits < 32

    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    if upcast:
        input_list = [x.to(torch.float32) for x in input_list]

    # reduce_scatter operates on contiguous buffers
    input_list = [x.contiguous() for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if upcast:
        output = output.to(dtype=dtype)

    return output


Boris Bonev's avatar
Boris Bonev committed
277
278
279
280
281
282
283
class _CopyToPolarRegion(torch.autograd.Function):
    """
    Identity in the forward pass. In the backward pass, the gradient is
    all-reduced over the polar group when the polar dimension is distributed.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # forward has a single input, so both branches return exactly one gradient
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        return grad_output
297
298
299
300
301
302
303
304
305
306
307


class _CopyToAzimuthRegion(torch.autograd.Function):
    """
    Identity in the forward pass. In the backward pass, the gradient is
    all-reduced over the azimuth group when the azimuth dimension is distributed.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # forward has a single input, so both branches return exactly one gradient
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        return grad_output


321
322
323
324
325
326
327
class _ScatterToPolarRegion(torch.autograd.Function):
    """
    Forward: keep only this rank's chunk of the input along ``dim_``.
    Backward: gather the gradient chunks back along ``dim_``.
    Both directions communicate over the polar group and are identities when
    the polar dimension is not distributed.
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if not is_distributed_polar():
            return input_

        ctx.dim = dim_
        # record how the full dimension is partitioned so backward can reassemble it
        ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None

        gathered = _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group())
        return gathered, None

Boris Bonev's avatar
Boris Bonev committed
347
348

class _GatherFromPolarRegion(torch.autograd.Function):
    """
    Forward: gather shards along ``dim_`` from all polar ranks.
    Backward: keep only the slice of the gradient that this rank contributed.
    Both directions are identities when the polar dimension is not distributed.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            # keep the per-rank sizes so backward slices exactly what this rank contributed
            ctx.shapes = shapes_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            if ctx.shapes is None:
                # equal shards were gathered; the default balanced split matches that layout
                grad_input = _split(grad_output, ctx.dim, group=polar_group())
            else:
                # slice with the sizes used in forward instead of re-deriving them,
                # so partitions other than compute_split_shapes' are handled correctly
                rank = dist.get_rank(group=polar_group())
                offset = sum(ctx.shapes[:rank])
                grad_input = grad_output.narrow(ctx.dim, offset, ctx.shapes[rank])
            return grad_input, None, None
        else:
            return grad_output, None, None

371
372
373
374
375
376
377
378
379
380
381
    
class _ReduceFromPolarRegion(torch.autograd.Function):
    """
    Forward: all-reduce over the polar group when the polar dimension is distributed.
    Backward: pass the gradient through unchanged.
    """

    @staticmethod
    def symbolic(graph, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # the reduced output depends on each local input with unit weight
        return grad_output

394
395
    
class _ReduceFromAzimuthRegion(torch.autograd.Function):
    """
    Forward: all-reduce over the azimuth group when the azimuth dimension is distributed.
    Backward: pass the gradient through unchanged.
    """

    @staticmethod
    def symbolic(graph, input_):
        if not is_distributed_azimuth():
            return input_
        return _reduce(input_, group=azimuth_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if not is_distributed_azimuth():
            return input_
        return _reduce(input_, group=azimuth_group())

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # the reduced output depends on each local input with unit weight
        return grad_output

Boris Bonev's avatar
Boris Bonev committed
417

Thorsten Kurth's avatar
Thorsten Kurth committed
418
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """
    Forward: reduce-scatter along ``dim_`` over the polar group.
    Backward: all-gather the gradient back along ``dim_``.
    Both directions are identities when the polar dimension is not distributed.
    """

    @staticmethod
    def symbolic(graph, input_, dim_):
        if not is_distributed_polar():
            return input_
        return _reduce_scatter(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if not is_distributed_polar():
            return input_

        ctx.dim = dim_
        # the per-rank chunk sizes let backward reassemble the full dimension
        ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
        return _reduce_scatter(input_, dim_, group=polar_group())

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None

        gathered = _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group())
        return gathered, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    """
    Forward: all-gather shards along ``dim_`` over the polar group.
    Backward: reduce-scatter the gradient along ``dim_`` in float32.
    Both directions are identities when the polar dimension is not distributed.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_

        ctx.dim = dim_
        return _gather(input_, dim_, shapes_, group=polar_group())

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None, None

        # NOTE(review): _reduce_scatter re-derives the chunk sizes via
        # compute_split_shapes and ignores the shapes_ passed to forward; this
        # assumes callers always gather with that same balanced split — confirm.
        scattered = _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group())
        return scattered, None, None
        

        
Boris Bonev's avatar
Boris Bonev committed
476
477
def copy_to_polar_region(input_):
    """Identity forward; the gradient is all-reduced over the polar group in backward."""
    return _CopyToPolarRegion.apply(input_)


def copy_to_azimuth_region(input_):
    """Identity forward; the gradient is all-reduced over the azimuth group in backward."""
    return _CopyToAzimuthRegion.apply(input_)


def reduce_from_polar_region(input_):
    """All-reduce over the polar group in forward; gradient passes through in backward."""
    return _ReduceFromPolarRegion.apply(input_)


def reduce_from_azimuth_region(input_):
    """All-reduce over the azimuth group in forward; gradient passes through in backward."""
    return _ReduceFromAzimuthRegion.apply(input_)


def scatter_to_polar_region(input_, dim_):
    """Keep this rank's chunk along ``dim_`` in forward; gather the gradient in backward."""
    return _ScatterToPolarRegion.apply(input_, dim_)


def gather_from_polar_region(input_, dim_, shapes_):
    """Gather polar shards along ``dim_`` in forward; keep this rank's gradient slice in backward."""
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)


def reduce_from_scatter_to_polar_region(input_, dim_):
    """Reduce-scatter along ``dim_`` over the polar group in forward; gather the gradient in backward."""
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)


def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    """Gather along ``dim_`` over the polar group in forward; reduce-scatter the gradient in backward."""
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)