# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List

import torch
import torch.distributed as dist
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth


# helper routine to compute an uneven split in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """
    Compute the split shapes for a given size and number of chunks.

    Parameters
    ----------
    size: int
        The size of the tensor to split
    num_chunks: int
        The number of chunks to split into

    Returns
    -------
    List[int]
        The split shapes
    """
    
    # treat trivial case first
    if num_chunks == 1:
        return [size]
    
    # first, check if we can split using div-up to balance the load: 
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks-1)

    # generate sections list
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections
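
# For example, the balanced split prefers equally sized leading chunks and places the
# remainder in the last chunk:
#
#   compute_split_shapes(10, 3) -> [4, 4, 2]
#   compute_split_shapes(7, 4)  -> [2, 2, 2, 1]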

    
def split_tensor_along_dim(tensor, dim, num_chunks):
    """
    Split a tensor along a given dimension into a given number of chunks.
    
    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to split
    dim: int
        The dimension to split along
    num_chunks: int
        The number of chunks to split into
        
    Returns
    -------
    tensor_list: List[torch.Tensor]  
        The split tensors
    """
    
    assert dim < tensor.dim(), f"Error, tensor has only {tensor.dim()} dimensions and cannot be split along dimension {dim}"
    assert (tensor.shape[dim] >= num_chunks), (
        f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
        f"{num_chunks} chunks. Empty slices are currently not supported."
    )

    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)
    
    return tensor_list
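
# For example, splitting a tensor of length 10 into 3 chunks along dim 0 yields chunk
# sizes matching compute_split_shapes(10, 3):
#
#   >>> t = torch.arange(10)
#   >>> [c.shape[0] for c in split_tensor_along_dim(t, dim=0, num_chunks=3)]
#   [4, 4, 2]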


def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    """
    Perform a distributed transpose between two dimensions of a tensor, using an
    all-to-all across the ranks of the given process group.

    Parameters
    ----------
    tensor: torch.Tensor
        The tensor to transpose
    dim0: int
        The first dimension to transpose
    dim1: int
        The second dimension to transpose
    dim1_split_sizes: List[int]
        The split sizes of dim1 on the participating ranks
    group: ProcessGroup, optional
        The process group used for the all-to-all communication
    async_op: bool, optional
        Whether the all-to-all is issued asynchronously

    Returns
    -------
    x_recv: List[torch.Tensor]
        The received tensor chunks, to be concatenated along dim1
    dim0_split_sizes: List[int]
        The split sizes along dim0
    req:
        The communication request handle (None unless async_op is True)
    """
    
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)

    # split and local transposition
    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
    x_send = [y.contiguous() for y in tsplit]
    x_send_shapes = [x.shape for x in x_send]
    x_recv = []
    x_shape = list(x_send_shapes[comm_rank])
    for dim1_len in dim1_split_sizes:
        x_shape[dim1] = dim1_len
        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    # get dim0 split sizes
    dim0_split_sizes = [x[dim0] for x in x_send_shapes]

    return x_recv, dim0_split_sizes, req
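
# Usage sketch (assumes an initialized process group `grp` and split sizes `sizes` that
# match how dim1 is sharded on the peer ranks): the global transpose is completed by
# concatenating the received chunks along dim1, as done in the autograd wrappers below.
#
#   chunks, dim0_sizes, _ = _transpose(x, dim0=-3, dim1=-1, dim1_split_sizes=sizes, group=grp)
#   x_t = torch.cat(chunks, dim=-1)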


class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        r"""
        Forward pass for distributed azimuthal transpose.

        Parameters
        ----------
        x: torch.Tensor
            The tensor to transpose
        dims: List[int]
            The dimensions to transpose
        dim1_split_sizes: List[int]
            The split sizes for the second dimension

        Returns
        -------
        x: torch.Tensor
            The transposed tensor
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        ctx.dims = dims
        ctx.dim0_split_sizes = dim0_split_sizes

        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed azimuthal transpose.

        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output

        Returns
        -------
        gi: torch.Tensor
            The gradient of the input
        """
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
        gi = torch.cat(gilist, dim=dims[0])

        return gi, None, None
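
# Usage sketch (assumes an initialized azimuth group; `x` is sharded along dims[0] and
# `w_split_sizes` is a placeholder for how dims[1] is sharded on the peer ranks):
#
#   y = distributed_transpose_azimuth.apply(x, (-1, -2), w_split_sizes)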

    
class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        r"""
        Forward pass for distributed polar transpose.

        Parameters
        ----------
        x: torch.Tensor
            The tensor to transpose
        dim: List[int]
            The dimensions to transpose
        dim1_split_sizes: List[int]
            The split sizes for the second dimension

        Returns
        -------
        x: torch.Tensor
            The transposed tensor
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for distributed polar transpose.

        Parameters
        ----------
        go: torch.Tensor
            The gradient of the output

        Returns
        -------
        gi: torch.Tensor
            The gradient of the input
        """
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None, None

    
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_
    
    # All-reduce.
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
        inputf_ = inputf_.contiguous()
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
        input_ = input_.contiguous()
        dist.all_reduce(input_, group=group)

    return input_
    

def _split(input_, dim_, group=None):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_
    
    # Split along last dimension.
    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    
    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank]
    
    return output


def _gather(input_, dim_, shapes_, group=None):
    """Gather unevenly split tensors across ranks"""
    
    comm_size = dist.get_world_size(group=group)

    if (shapes_ is not None) and (len(shapes_) != comm_size):
        raise ValueError(f"Error, expected {comm_size} shapes but got {len(shapes_)}.")
    if dim_ >= input_.dim():
        raise ValueError(f"Error, cannot gather along dimension {dim_} for a tensor with {input_.dim()} dimensions.")

    if comm_size == 1:
        return input_

    # make contiguous:
    input_ = input_.contiguous()
    input_shape = list(input_.shape)

    if shapes_ is not None:
        input_list = []
        for src in range(comm_size):
            input_shape[dim_] = shapes_[src]
            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
    else:
        # assume equal shape on all ranks
        input_list = [torch.empty_like(input_) for _ in range(comm_size)]

    dist.all_gather(input_list, input_, group=group)

    output = torch.cat(input_list, dim=dim_)

    return output
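
# Usage sketch with placeholder names (`grp`, `full_size`, `local_shard`), assuming an
# initialized process group: gather shards produced by an uneven split, passing the
# per-rank shapes so the receive buffers can be allocated.
#
#   shapes = compute_split_shapes(full_size, dist.get_world_size(group=grp))
#   full = _gather(local_shard, dim_=-1, shapes_=shapes, group=grp)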


def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    """All-reduce the input tensor across model parallel group and scatter it back."""

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    input_list = split_tensor_along_dim(input_, dim_, comm_size)

    dtype = input_.dtype
    if (use_fp32 and (dtype != torch.float32)):
        input_list = [x.to(torch.float32) for x in input_list]

    input_list = [x.contiguous() for x in input_list]

    # perform reduce_scatter
    output = torch.empty_like(input_list[comm_rank])
    dist.reduce_scatter(output, input_list, group=group)

    # convert dtype if necessary
    if use_fp32:
        output = output.to(dtype=dtype)

    return output
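
# Equivalence sketch with placeholder names, assuming an initialized group `grp`:
# _reduce_scatter fuses what would otherwise be an all-reduce followed by a split,
#
#   y1 = _reduce_scatter(x, dim_=-1, group=grp)
#   y2 = _split(_reduce(x.clone(), group=grp), dim_=-1, group=grp)
#
# so y1 and y2 hold the same values on every rank (up to floating point rounding),
# without materializing the fully reduced tensor on each rank.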


class _CopyToPolarRegion(torch.autograd.Function):
    
    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to polar region.

        Parameters
        ----------
        input_: torch.Tensor
            The tensor to copy

        Returns
        -------
        input_: torch.Tensor
            The unchanged input tensor
        """
        return input_
    
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to polar region.

        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output

        Returns
        -------
        grad_input: torch.Tensor
            The gradient of the input (all-reduced over the polar group when distributed)
        """
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output


class _CopyToAzimuthRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to azimuth region.

        Parameters
        ----------
        input_: torch.Tensor
            The tensor to copy

        Returns
        -------
        input_: torch.Tensor
            The unchanged input tensor
        """
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to azimuth region.

        Parameters
        ----------
        grad_output: torch.Tensor
            The gradient of the output

        Returns
        -------
        grad_input: torch.Tensor
            The gradient of the input (all-reduced over the azimuth group when distributed)
        """
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
            return grad_output


class _ScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _split(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None

    
class _ReduceFromPolarRegion(torch.autograd.Function):
    
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_polar():
            return _reduce(input_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output

    
class _ReduceFromAzimuthRegion(torch.autograd.Function):
 
    @staticmethod
    def symbolic(graph, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        if is_distributed_azimuth():
            return _reduce(input_, group=azimuth_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output


class _ReduceFromScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
        if is_distributed_polar():
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_):
        if is_distributed_polar():
            ctx.dim = dim_
            ctx.split_shapes = compute_split_shapes(
                input_.shape[dim_], polar_group_size()
            )
            return _reduce_scatter(input_, dim_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
        else:
            return grad_output, None


class _GatherFromCopyToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        if is_distributed_polar():
            return _gather(input_, dim_, shapes_, polar_group())
        else:
            return input_

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
        else:
            return grad_output, None, None
        

        
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)

def copy_to_azimuth_region(input_):
    return _CopyToAzimuthRegion.apply(input_)
        
def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

def reduce_from_azimuth_region(input_):
    return _ReduceFromAzimuthRegion.apply(input_)

def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)

def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)

def reduce_from_scatter_to_polar_region(input_, dim_):
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
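
# End-to-end sketch with placeholder tensors, assuming the process groups in .utils have
# been initialized with a polar group: scatter the input along a spatial dimension,
# compute on the local shard, and gather the result back.
#
#   shapes = compute_split_shapes(x.shape[-2], polar_group_size())
#   x_local = scatter_to_polar_region(x, -2)
#   y_local = some_local_op(x_local)   # hypothetical per-shard computation
#   y = gather_from_polar_region(y_local, -2, shapes)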