# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List

import torch
import torch.distributed as dist
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth

# helper routine to compute an uneven split in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    
    # treat trivial case first
    if num_chunks == 1:
        return [size]
    
    # first, check if we can split using div-up to balance the load: 
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks-1)

    # generate sections list
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections
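
# illustrative examples (worked out from the routine above, not part of the API):
#   compute_split_shapes(10, 4) -> [3, 3, 3, 1]   ceil split, remainder goes to the last shard
#   compute_split_shapes(9, 4)  -> [2, 2, 2, 3]   a ceil split would leave the last shard empty,
#                                                 so the floor split is used instead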

    
def split_tensor_along_dim(tensor, dim, num_chunks):
    assert dim < tensor.dim(), f"Error, tensor with {tensor.dim()} dimensions cannot be split along dim {dim}"
    assert tensor.shape[dim] >= num_chunks, f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into {num_chunks} chunks. Empty slices are currently not supported."
    
    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)
    
    return tensor_list


def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # split and local transposition
    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
    x_send = [y.contiguous() for y in tsplit]
    x_send_shapes = [x.shape for x in x_send]
    x_recv = []
    x_shape = list(x_send_shapes[comm_rank])
    for dim1_len in dim1_split_sizes:
        x_shape[dim1] = dim1_len
        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
        
    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # get dim0 split sizes
    dim0_split_sizes = [x[dim0] for x in x_send_shapes]
    
    return x_recv, dim0_split_sizes, req
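
# illustrative usage sketch (hypothetical names and shapes; assumes torch.distributed and
# the communicator group are initialized): re-shard a tensor that is currently split along
# dim 3 across the group into one that is split along dim 2 instead:
#   dim3_split_sizes = compute_split_shapes(nlon, dist.get_world_size(group=group))  # nlon: global size of dim 3
#   x_recv, dim2_split_sizes, _ = _transpose(x_local, 2, 3, dim3_split_sizes, group=group)
#   x_local = torch.cat(x_recv, dim=3)
# afterwards dim 3 is fully local and dim 2 is distributed; dim2_split_sizes records how
# dim 2 was split, so the operation can be inverted by another _transpose call.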


class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        # workaround (WAR) for a potential contiguity-check bug in torch with channels-last contiguous tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes
        
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # workaround (WAR) for a potential contiguity-check bug in torch with channels-last contiguous tensors
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0])
        
        return gi, None, None

    
class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        # workaround (WAR) for a potential contiguity-check bug in torch with channels-last contiguous tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # workaround (WAR) for a potential contiguity-check bug in torch with channels-last contiguous tensors
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None, None
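
# illustrative usage sketch (hypothetical tensor layout; assumes the azimuth group is
# initialized): gather the longitude dim locally and distribute the channel dim instead,
# e.g. before an FFT along longitude:
#   lon_split_sizes = compute_split_shapes(nlon, dist.get_world_size(group=azimuth_group()))
#   x = distributed_transpose_azimuth.apply(x, (1, -1), lon_split_sizes)
# dim 1 (channels) becomes distributed and dim -1 (longitude) becomes fully local; the
# inverse transpose swaps the two dims and passes the channel split sizes instead.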

    
# we need these additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across the model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_
    
    # All-reduce.
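    # note: with use_fp32=True the reduction is carried out in float32 and the result is
    # cast back to the original dtype, which avoids precision loss when summing many
    # low-precision (e.g. bf16/fp16) contributions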
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
        inputf_ = inputf_.contiguous()
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
        input_ = input_.contiguous()
        dist.all_reduce(input_, group=group)
        
    return input_
    

def _split(input_, dim_, group=None):
    """Split the tensor along dimension dim_ and keep the slice corresponding to this rank."""
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_
    
    # Split along the requested dimension.
    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    
    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank]
    
    return output


def _gather(input_, dim_, shapes_, group=None):
    """Gather tensors which are unevenly split along dim_ across the ranks of the group."""
    
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(f"Error, the number of shapes {len(shapes_)} does not match the group size {comm_size}")
    if dim_ >= input_.dim():
        raise ValueError(f"Error, cannot gather along dim {dim_} for a tensor with {input_.dim()} dimensions")

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        input_list = []
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_)

    return output
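
# illustrative pairing (hypothetical names; every rank is assumed to hold the same full
# tensor x before the split): _split and _gather invert each other,
#   shapes  = compute_split_shapes(x.shape[dim], dist.get_world_size(group=group))
#   x_local = _split(x, dim, group=group)                  # keep only this rank's chunk
#   x_full  = _gather(x_local, dim, shapes, group=group)   # reassemble the full tensor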


def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """Reduce the input tensor across the model parallel group and scatter the result along dim_."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # get comm params and split the input along dim_
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    input_list = split_tensor_along_dim(input_, dim_, comm_size)

    dtype = input_.dtype
    if (use_fp32 and (dtype != torch.float32)):
        input_list = [x.to(torch.float32) for x in input_list]

    input_list = [x.contiguous() for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if use_fp32:
        output = output.to(dtype=dtype)

    return output
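
# note: _reduce_scatter(x, dim_) produces the same result as _split(_reduce(x), dim_),
# i.e. a chunk-wise sum across the group, but uses a single fused collective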


class _CopyToPolarRegion(torch.autograd.Function):
    """Copy the input to the polar region: identity in the forward pass, all-reduce of the gradient in the backward pass."""
    
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_
    
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output


class _CopyToAzimuthRegion(torch.autograd.Function):
    """Copy the input to the azimuth region: identity in the forward pass, all-reduce of the gradient in the backward pass."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
            return grad_output


class _ScatterToPolarRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chunk to the rank."""

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _split(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromPolarRegion(torch.autograd.Function):
    """Gather the input from all ranks of the polar region; split the gradient in the backward pass."""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None

    
class _ReduceFromPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region."""
    
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output

    
class _ReduceFromAzimuthRegion(torch.autograd.Function):
    """All-reduce the input from the azimuth region."""

    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output


class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """Reduce the input across the polar region and scatter the result along dim_."""

    @staticmethod
    def symbolic(graph, input_, dim_):
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    """Gather the input from the polar region and reduce-scatter the gradient in the backward pass; the inverse of _ReduceFromScatterToPolarRegion."""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)

def copy_to_azimuth_region(input_):
    return _CopyToAzimuthRegion.apply(input_)
        
def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

def reduce_from_azimuth_region(input_):
    return _ReduceFromAzimuthRegion.apply(input_)

def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)

def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)

def reduce_from_scatter_to_polar_region(input_, dim_):
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
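

# illustrative end-to-end sketch (hypothetical layer; assumes the polar process group has
# been initialized via .utils): apply a weight that is replicated across the polar group
# to an input that is sharded along a spatial dim:
#   shapes = compute_split_shapes(x.shape[-2], polar_group_size())
#   x = scatter_to_polar_region(x, -2)             # forward: split, backward: gather
#   w = copy_to_polar_region(w)                    # forward: identity, backward: all-reduce
#   y = some_local_op(x, w)                        # hypothetical rank-local computation
#   y = gather_from_polar_region(y, -2, shapes)    # forward: gather, backward: split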