primitives.py 13.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
31
from typing import List
Boris Bonev's avatar
Boris Bonev committed
32
33
34
35

import torch
import torch.distributed as dist

36
37
from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar
Boris Bonev's avatar
Boris Bonev committed
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """Return ``num_chunks`` section lengths summing to ``size``.

    The split uses ceil-division so the chunks are balanced; if that would
    leave the final section empty, it falls back to floor-division and the
    final section absorbs the remainder.
    """

    # trivial case: one chunk takes everything
    if num_chunks == 1:
        return [size]

    # ceil-division split to balance the load
    chunk = -(-size // num_chunks)
    last = size - chunk * (num_chunks - 1)
    if last <= 0:
        # the last shard would be empty; split with floor instead
        chunk = size // num_chunks
        last = size - chunk * (num_chunks - 1)

    return [chunk] * (num_chunks - 1) + [last]

    
Boris Bonev's avatar
Boris Bonev committed
60
61
def split_tensor_along_dim(tensor, dim, num_chunks):
    """Split ``tensor`` along dimension ``dim`` into ``num_chunks`` balanced slices.

    Returns the sequence of (possibly non-contiguous) views produced by
    ``torch.split`` using the balanced section sizes from
    ``compute_split_shapes``.
    """
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
    assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
                                              {num_chunks} chunks. Empty slices are currently not supported."

    # balanced (possibly uneven) section sizes, then split in one call
    split_sizes = compute_split_shapes(tensor.shape[dim], num_chunks)

    return torch.split(tensor, split_sizes, dim=dim)

71
72

def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """Distributed transpose: split the local tensor along ``dim0`` and
    exchange the chunks via all-to-all, so each rank receives shards whose
    ``dim1`` extents are given by ``dim1_split_sizes``.

    Returns ``(x_recv, dim0_split_sizes, req)``: the list of received shards
    (to be concatenated along ``dim1`` by the caller), the per-rank ``dim0``
    sizes needed to undo the transpose, and the all-to-all request handle
    (``None`` unless ``async_op`` is set).
    """
    # communicator parameters
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # local split along dim0; collectives need contiguous buffers
    x_send = [shard.contiguous() for shard in split_tensor_along_dim(tensor, dim=dim0, num_chunks=comm_size)]
    x_send_shapes = [shard.shape for shard in x_send]

    # allocate receive buffers: our own dim0 extent, per-peer dim1 extent
    recv_shape = list(x_send_shapes[comm_rank])
    x_recv = []
    for dim1_len in dim1_split_sizes:
        recv_shape[dim1] = dim1_len
        x_recv.append(torch.empty(recv_shape, dtype=tensor.dtype, device=tensor.device))

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # dim0 split sizes let the caller perform the inverse transpose later
    dim0_split_sizes = [shape[dim0] for shape in x_send_shapes]

    return x_recv, dim0_split_sizes, req
Boris Bonev's avatar
Boris Bonev committed
94
95


Boris Bonev's avatar
Boris Bonev committed
96
class distributed_transpose_azimuth(torch.autograd.Function):
    """Autograd-aware distributed transpose over the azimuth communicator.

    Forward splits ``x`` along ``dims[0]``, exchanges the chunks via
    all-to-all (inside ``_transpose``), and concatenates the received shards
    along ``dims[1]``; backward performs the inverse redistribution of the
    incoming gradient.
    """

    @staticmethod
    def forward(ctx, x, dims, dim1_split_sizes):
        # WAR for a potential contig check torch bug for channels last contig tensors
        x = x.contiguous()
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        # re-assemble the received shards along the new local dimension
        x = torch.cat(xlist, dim=dims[1]).contiguous()
        # stash metadata needed to undo the transpose in backward
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes

        return x

    @staticmethod
    def backward(ctx, go):
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        go = go.contiguous()
        # inverse transpose: dims[0] and dims[1] swap roles
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0]).contiguous()

        # no gradients for the dims / dim1_split_sizes arguments
        return gi, None, None
Boris Bonev's avatar
Boris Bonev committed
119
120

    
Boris Bonev's avatar
Boris Bonev committed
121
class distributed_transpose_polar(torch.autograd.Function):
    """Autograd-aware distributed transpose over the polar communicator.

    Forward splits ``x`` along ``dim[0]``, exchanges the chunks via
    all-to-all (inside ``_transpose``), and concatenates the received shards
    along ``dim[1]``; backward performs the inverse redistribution of the
    incoming gradient.
    """

    @staticmethod
    def forward(ctx, x, dim, dim1_split_sizes):
        # WAR for a potential contig check torch bug for channels last contig tensors 
        x = x.contiguous()
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        # re-assemble the received shards along the new local dimension
        x = torch.cat(xlist, dim=dim[1]).contiguous()
        # stash metadata needed to undo the transpose in backward
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    def backward(ctx, go):
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors 
        go = go.contiguous()
        # inverse transpose: dim[0] and dim[1] swap roles
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0]).contiguous()
        # no gradients for the dim / dim1_split_sizes arguments
        return gi, None, None
Boris Bonev's avatar
Boris Bonev committed
142

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
    
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group."""

    # nothing to do for a single-rank group
    if dist.get_world_size(group=group) == 1:
        return input_

    # collectives require contiguous buffers
    input_ = input_.contiguous()

    if not use_fp32:
        dist.all_reduce(input_, group=group)
        return input_

    # accumulate in float32, then cast back to the original dtype
    orig_dtype = input_.dtype
    reduced = input_.float()
    dist.all_reduce(reduced, group=group)

    return reduced.to(orig_dtype)
Thorsten Kurth's avatar
Thorsten Kurth committed
165
    
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219

def _split(input_, dim_, group=None):
    """Split the tensor along dimension ``dim_`` and keep the slice corresponding to this rank."""
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_
    
    # Split along the requested dimension (balanced, possibly uneven chunks).
    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    
    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()
    
    return output


def _gather(input_, dim_, shapes_, group=None):
    """Gather unevenly split tensors across ranks.

    Parameters
    ----------
    input_: local shard to contribute to the gather
    dim_: dimension along which the gathered shards are concatenated
    shapes_: per-rank sizes of dimension ``dim_``; if None, all ranks are
        assumed to hold shards of identical shape
    group: process group to gather over (None selects the default group)

    Raises
    ------
    ValueError: if ``shapes_`` does not have one entry per rank, or if
        ``dim_`` is out of range for ``input_``.
    """
    
    comm_size = dist.get_world_size(group=group)

    # validate arguments early with actionable messages (previously these
    # raised bare ValueError() with no explanation)
    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(f"Error, shapes_ must have one entry per rank: got {len(shapes_)} entries for group size {comm_size}")
    if dim_ >= input_.dim():
        raise ValueError(f"Error, invalid gather dimension {dim_} for tensor with {input_.dim()} dimensions")

    # nothing to do for a single-rank group
    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        # pre-allocate receive buffers with the per-rank dim_ extents
        input_list = [None] * comm_size

        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list[src] = torch.empty(
                input_shape,
                dtype=input_.dtype,
                device=input_.device,
            )
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_).contiguous()

    return output
Boris Bonev's avatar
Boris Bonev committed
220
221


Thorsten Kurth's avatar
Thorsten Kurth committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group and scatter it back.

    The tensor is split along ``dim_`` into one chunk per rank; after the
    reduce-scatter each rank holds the reduced chunk corresponding to its
    rank index. With ``use_fp32`` the reduction runs in float32 and the
    result is cast back to the input dtype.
    """

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # make input contiguous
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    input_list = [x.contiguous() for x in split_tensor_along_dim(input_, dim_, comm_size)]

    dtype = input_.dtype
    # upcast to float32 for the reduction if requested and not already fp32
    if (use_fp32 and (dtype != torch.float32)):
        input_list = [x.to(torch.float32) for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if use_fp32:
        output = output.to(dtype=dtype)

    return output


Boris Bonev's avatar
Boris Bonev committed
249
250
251
252
253
254
255
256
257
258
class _CopyToPolarRegion(torch.autograd.Function):
    """Pass the input through unchanged; all-reduce the gradient over the polar group.

    Forward is the identity. Backward all-reduces the incoming gradient
    across the polar process group when it is distributed, and is the
    identity otherwise. (The original docstring was copied from the scatter
    primitive and did not describe this class.)
    """
    
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    def forward(ctx, input_):
        # identity: the same value is used on every polar rank
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            # sum the gradient contributions from all polar ranks
            return _reduce(grad_output, group=polar_group())
        else:
            # BUGFIX: forward takes a single input, so backward must return a
            # single gradient; the previous `return grad_output, None` reported
            # one gradient too many in the non-distributed path.
            return grad_output
        
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    
class _ScatterToPolarRegion(torch.autograd.Function):
    """Scatter the input along ``dim_`` across the polar group; each rank keeps its chunk."""

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    def forward(ctx, input_, dim_):
        if not is_distributed_polar():
            return input_
        # remember how the full tensor was carved up so backward can reassemble it
        ctx.dim = dim_
        ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None
        # gather the per-rank gradient shards back into the full gradient
        return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None

Boris Bonev's avatar
Boris Bonev committed
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315

class _GatherFromPolarRegion(torch.autograd.Function):
    """Gather the per-rank chunks along ``dim_`` from the polar group into one tensor."""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    def forward(ctx, input_, dim_, shapes_):
        if not is_distributed_polar():
            return input_
        # remember the gather dimension so backward can split along it
        ctx.dim = dim_
        return _gather(input_, dim_, shapes_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        if not is_distributed_polar():
            return grad_output, None, None
        # hand each rank back its own slice of the gradient
        return _split(grad_output, ctx.dim, group=polar_group()), None, None

316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    
class _ReduceFromPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region."""

    @staticmethod
    def symbolic(graph, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    def forward(ctx, input_):
        if not is_distributed_polar():
            return input_
        return _reduce(input_, group=polar_group())

    @staticmethod
    def backward(ctx, grad_output):
        # the gradient passes through unchanged
        return grad_output

Boris Bonev's avatar
Boris Bonev committed
338

Thorsten Kurth's avatar
Thorsten Kurth committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region and scatter back to polar region."""

    @staticmethod
    def symbolic(graph, input_, dim_):
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            # remember the split layout so backward can gather the gradient shards
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            # gather the per-rank gradient shards back along the split dimension;
            # None for the non-tensor dim_ argument
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):
    """Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter"""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            # remember the gather dimension for the backward reduce-scatter
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            # reduce-scatter the gradient (fp32 accumulation); None for the
            # non-tensor dim_ / shapes_ arguments
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
Boris Bonev's avatar
Boris Bonev committed
395
396
397
def copy_to_polar_region(input_):
    """Functional wrapper for :class:`_CopyToPolarRegion`."""
    return _CopyToPolarRegion.apply(input_)
    
398
399
400
401
402
403
404
    
def reduce_from_polar_region(input_):
    """Functional wrapper for :class:`_ReduceFromPolarRegion`."""
    return _ReduceFromPolarRegion.apply(input_)


def scatter_to_polar_region(input_, dim_):
    """Functional wrapper for :class:`_ScatterToPolarRegion`."""
    return _ScatterToPolarRegion.apply(input_, dim_)
Boris Bonev's avatar
Boris Bonev committed
405
406
407
408


def gather_from_polar_region(input_, dim_, shapes_):
    """Functional wrapper for :class:`_GatherFromPolarRegion`."""
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
Thorsten Kurth's avatar
Thorsten Kurth committed
409
410
411
412
413
414
415
416


def reduce_from_scatter_to_polar_region(input_, dim_):
    """Functional wrapper for :class:`_ReduceFromScatterToPolarRegion`."""
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)


def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    """Functional wrapper for :class:`_GatherFromCopyToPolarRegion`."""
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)