Unverified Commit 29e7fb68 authored by Boris Bonev, committed by GitHub

Tkurth/cuda disco (#38)



* adding cuda kernels for disco conv

* making psi_idx an attribute

* adding license headers

* adding author files

* reorganizing files

* draft implementation

* added conditional installation to setup.py

* formatting changes

* removing triton kernel in DISCO convolution

* updated github actions

* updated Readme and changelog

* adding another guard for the cuda installation

* renaming the cuda extension

* simplifying setup.py

* minor bugfix

* Bbonev/cuda disco cleanup (#32)

* cleanup of disco convolutions based on CUDA extension

* fixing unittest

* changing version to experimental 0.7.0a

* initial rewrite of the distributed convolution with CUDA

* fixing streams

* need to fix install options

* fixing streams

* undid setup.py changes

* reset setup.py

* including CUDAStream

* adjusted the precomputation of theta_cutoff. If you rely on this, your models will not be backwards-compatible.

* adjusting theta_cutoff in the unittest

* adding newly refactored kernels for faster compile

* Tkurth/cuda disco distributed fix (#34)

* attempt to make disco distributed

* working distributed convolutions

* fixing distributed conv

* working distributed disco

* removing irrelevant extra argument

* using stream functions from at instead of c10

* using stream functions from at instead of c10, small fix

* Bbonev/disc even filters (#35)

* initial working commit with new convention of counting collocation points across the diameter instead of across the radius

* fixed a bug in the computation of the even kernels

* changing heuristic for computing theta_cutoff

* Fixing unittest

* Readability improvements

* reworked normalization of filter basis functions

* implemented discrete normalization of disco filters

* relaxing tolerances in convolution unit test

* bugfix to correctly support unequal scale factors in latitudes and longitudes

* hotfix to a bug in the imports

* Bbonev/distributed disco refactor (#37)

* cleaned up normalization code in convolution

* formatting changes in distributed convolution

* Fixing default theta_cutoff to be the same in distributed and local case

* fixed distributed convolution to support the same normalization as non-distributed one

* readability improvements

* fixed initial scale of convolution parameter weights and fixed naming of the normalization routine

* Updated Readme.md

* added comment in Dockerfile regarding older architectures

---------
Co-authored-by: Thorsten Kurth <tkurth@nvidia.com>
Co-authored-by: Boris Bonev <bbonev@nvidia.com>
parent 214fa40a
@@ -31,6 +31,8 @@
import abc
from typing import List, Tuple, Union, Optional
from itertools import accumulate
from warnings import warn
import math
@@ -40,28 +42,37 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import (
_disco_s2_contraction_torch,
_disco_s2_transpose_contraction_torch,
_disco_s2_contraction_triton,
_disco_s2_transpose_contraction_triton,
)
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.convolution import (
_compute_support_vals_isotropic,
_compute_support_vals_anisotropic,
_precompute_convolution_tensor_2d,
_normalize_convolution_tensor_s2,
DiscreteContinuousConv,
)
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region
from torch_harmonics.distributed import copy_to_polar_region, reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular",
theta_cutoff=0.01 * math.pi, distributed_mode="columns"):
# import custom C++/CUDA extensions
from disco_helpers import preprocess_psi
try:
import disco_cuda_extension
_cuda_extension_available = True
except ImportError as err:
disco_cuda_extension = None
_cuda_extension_available = False
def _precompute_distributed_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -70,50 +81,36 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
The rotation uses the YZY Euler angle convention, which applied to the north pole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
This is the distributed version: the matrix can be split either column- or row-wise. Column-wise splitting is preferable because
the kernel performs many atomic summations for the row reductions, which can then be combined into a single allreduce.
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# split the latitude vector:
comm_size_polar = polar_group_size()
comm_rank_polar = polar_group_rank()
if distributed_mode == "columns":
lats_in = split_tensor_along_dim(lats_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
elif distributed_mode == "rows":
lats_out = split_tensor_along_dim(lats_out, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
nlat_out = lats_out.shape[0]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
# compute the phi differences
# It is important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
out_idx = []
out_vals = []
for t in range(nlat_out):
@@ -151,11 +148,36 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
# perform the normalization over the entire psi matrix
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
# TODO: this part can be split off into its own function
# split the latitude indices:
comm_size_polar = polar_group_size()
comm_rank_polar = polar_group_rank()
split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
offsets = [0] + list(accumulate(split_shapes))
start_idx = offsets[comm_rank_polar]
end_idx = offsets[comm_rank_polar+1]
# once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
lats = out_idx[2] // nlon_in
lons = out_idx[2] % nlon_in
ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
out_vals = out_vals[ilats]
# for the indices we need to recompute them to refer to local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats]-start_idx) * nlon_in + lons[ilats]], dim=0)
return out_idx, out_vals
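To make the index localization above concrete, here is a minimal, self-contained sketch; the sizes and indices are invented for illustration and are not part of the diff:

import torch

nlon_in = 8
start_idx, end_idx = 4, 8                  # input latitudes owned by this polar rank
col = torch.tensor([5, 37, 61])            # global flat column indices: lat * nlon_in + lon

lats = col // nlon_in                      # tensor([0, 4, 7])
lons = col % nlon_in                       # tensor([5, 5, 5])
keep = (lats >= start_idx) & (lats < end_idx)
local = (lats[keep] - start_idx) * nlon_in + lons[keep]
print(local)                               # tensor([ 5, 29]); the lat=0 entry is discarded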
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
@@ -197,7 +219,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
# compute theta cutoff based on the bandlimit of the input field
if theta_cutoff is None:
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_out - 1)
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
@@ -205,74 +227,73 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / float(self.nlon_in)
# Note that the psi matrix has shape nlat_out x (nlat_in * nlon_in). Since the contraction in the nlon direction is a convolution,
# we keep nlon local to all ranks and split the computation along nlat. We further split the input dimension because this reduces the number
# of atomic reduction calls inside the actual kernel.
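# Illustrative note (added for clarity, not part of the original change): with, e.g.,
# comm_size_polar = 2 and nlat_in = 8, each polar rank keeps all nlon_in longitudes but
# only a 4-latitude slice of the input, so its local psi block has shape
# (kernel_size, nlat_out, 4 * nlon_in); the partial sums produced from this column split
# are combined later in forward via reduce_from_polar_region.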
distributed_mode = "columns"
# set local shapes according to distributed mode:
if distributed_mode == "columns":
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
elif distributed_mode == "rows":
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out,
theta_cutoff=theta_cutoff, distributed_mode=distributed_mode)
# split the weight tensor as well
if distributed_mode == "columns":
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
idx, vals = _precompute_distributed_convolution_tensor_s2(
in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
)
# split the weight tensor as well
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
self.register_buffer("quad_weights", quad_weights, persistent=False)
self.register_buffer("psi_idx", idx, persistent=False)
# sort the values
ker_idx = idx[0, ...].contiguous()
row_idx = idx[1, ...].contiguous()
col_idx = idx[2, ...].contiguous()
roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals)
# preprocessed data-structure for GPU kernel
self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
@property
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce()
return psi
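A tiny standalone sketch of the sparse layout returned by get_psi; the indices and sizes are invented for illustration and are not part of the diff:

import torch

# index rows: (kernel index, output latitude, flattened input index lat * nlon_in + lon)
idx = torch.tensor([[0, 0, 1],
                    [0, 1, 0],
                    [2, 5, 7]])
vals = torch.tensor([0.3, 0.7, 1.0])
psi = torch.sparse_coo_tensor(idx, vals, size=(2, 2, 8)).coalesce()
print(psi.to_dense().shape)                # torch.Size([2, 2, 8])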
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# store number of channels
num_chans = x.shape[1]
#print("input shape", x.shape)
# h and w are split. First we make w local by transposing it into the channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
#print("transposed shape", x.shape)
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
#print("multiplied shape", x.shape)
psi = self.get_psi()
if x.is_cuda and _cuda_extension_available:
x = _disco_s2_contraction_cuda(
x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
)
else:
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
#print("psi shape", psi.shape)
psi = self.get_psi()
if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
else:
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
#print("psi * x shape", x.shape)
# allreduce over latitudes: h is still local
x = reduce_from_polar_region(x)
#print("reduced shape", x.shape)
# split tensor along latitudes: h is split
x = scatter_to_polar_region(x, -2)
#print("scattered shape", x.shape)
# now we can transpose back the result, so that lon is split and channels are local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
@@ -284,7 +305,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
# do weight multiplication
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
out = out.reshape(out.shape[0], -1, H, W)
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
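A small, self-contained sketch of the grouped weight contraction used above; the shapes are invented for the example and are not part of the diff:

import torch

B, G, C, K, H, W = 2, 3, 4, 5, 6, 7        # batch, groups, channels per group, kernel size, lat, lon
O = 8                                      # output channels per group
x = torch.randn(B, G, C, K, H, W)          # "bgckxy"
w = torch.randn(G, O, C, K)                # "gock"
out = torch.einsum("bgckxy,gock->bgoxy", x, w)   # contract over c and k
print(out.shape)                           # torch.Size([2, 3, 8, 6, 7])
out = out.reshape(B, G * O, H, W)          # merge the groups back into the channel dimension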
@@ -331,7 +352,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# bandlimit
if theta_cutoff is None:
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
@@ -342,44 +363,41 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# Note that the psi matrix has shape nlat_out x (nlat_in * nlon_in). Since the contraction in the nlon direction is a convolution,
# we keep nlon local to all ranks and split the computation along nlat. We further split the input dimension because this reduces the number
# of atomic reduction calls inside the actual kernel
distributed_mode = "columns"
# of atomic reduction calls inside the actual kernel
# set local shapes according to distributed mode:
if distributed_mode == "columns":
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
elif distributed_mode == "rows":
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx, vals = _precompute_distributed_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in,
theta_cutoff=theta_cutoff, distributed_mode="rows" if distributed_mode else "columns")
## do partial transpose
## we do a semi-transposition to facilitate the computation
#tout = iidx[2] // self.nlon_out
#pout = iidx[2] % self.nlon_out
## flip the axis of longitudes
#pout = self.nlon_out - 1 - pout
#tin = iidx[1]
#idx = torch.stack([iidx[0], tout, tin*self.nlon_out + pout], dim=0)
# split the weight tensor as well
if distributed_mode == "columns":
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
idx, vals = _precompute_distributed_convolution_tensor_s2(
out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
)
# register all buffers
# split the weight tensor as well
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
self.register_buffer("quad_weights", quad_weights, persistent=False)
self.register_buffer("psi_idx", idx, persistent=False)
# sort the values
ker_idx = idx[0, ...].contiguous()
row_idx = idx[1, ...].contiguous()
col_idx = idx[2, ...].contiguous()
roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals)
# preprocessed data-structure for GPU kernel
self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
def get_psi(self, use_triton_kernel=True):
if not use_triton_kernel:
@property
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False):
if semi_transposed:
# do partial transpose
# we do a semi-transposition to facilitate the computation
tout = self.psi_idx[2] // self.nlon_out
@@ -387,41 +405,45 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin*self.nlon_out + pout], dim=0)
idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce()
return psi
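To illustrate the semi-transposition performed in get_psi above, here is a tiny sketch with made-up indices (not part of the diff):

import torch

nlon_out = 8
flat = torch.tensor([3, 11])               # flat output indices: t_out * nlon_out + p_out
tout = flat // nlon_out                    # tensor([0, 1])
pout = flat % nlon_out                     # tensor([3, 3])
pout = nlon_out - 1 - pout                 # flip the longitude axis -> tensor([4, 4])
tin = torch.tensor([2, 5])                 # input latitude rows
new_flat = tin * nlon_out + pout           # tensor([20, 44])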
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
# do weight multiplication
x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
x = x.reshape(B, -1, x.shape[-3], H, W)
num_chans = x.shape[1]
# transpose such that lon is local, channels are split
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
# pre-multiply x with the quadrature weights
# multiply weights
x = self.quad_weights * x
if x.is_cuda and use_triton_kernel:
psi = self.get_psi(True)
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
# we need to gather the input tensor
x = gather_from_polar_region(x, -2, self.lat_in_shapes)
# register allreduce for bwd pass
x = copy_to_polar_region(x)
if x.is_cuda and _cuda_extension_available:
out = _disco_s2_transpose_contraction_cuda(
x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
)
else:
psi = self.get_psi(False)
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi(semi_transposed=True)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
# allreduce over latitudes: h is still local
out = reduce_from_polar_region(out)
# split tensor along latitudes: h is split
out = scatter_to_polar_region(out, -2)
# now we can transpose back the result, so that lon is split and channels are local
if self.comm_size_azimuth > 1:
@@ -432,4 +454,3 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out = out + self.bias.reshape(1, -1, 1, 1)
return out
@@ -230,7 +230,26 @@ def _gather(input_, dim_, shapes_, group=None):
output = torch.cat(input_list, dim=dim_).contiguous()
return output
class _CopyToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
if is_distributed_polar():
return _reduce(grad_output, group=polar_group())
else:
return grad_output
class _ScatterToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@@ -257,6 +276,29 @@ class _ScatterToPolarRegion(torch.autograd.Function):
else:
return grad_output, None
class _GatherFromPolarRegion(torch.autograd.Function):
"""Gather the input and keep it on the rank."""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
return _gather(input_, dim_, shapes_, polar_group())
@staticmethod
def forward(ctx, input_, dim_, shapes_):
if is_distributed_polar():
ctx.dim = dim_
return _gather(input_, dim_, shapes_, group=polar_group())
else:
return input_
@staticmethod
def backward(ctx, grad_output):
if is_distributed_polar():
return _split(grad_output, ctx.dim, group=polar_group()), None, None
else:
return grad_output, None, None
class _ReduceFromPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region."""
@@ -279,6 +321,10 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
def backward(ctx, grad_output):
return grad_output
def copy_to_polar_region(input_):
return _CopyToPolarRegion.apply(input_)
def reduce_from_polar_region(input_):
return _ReduceFromPolarRegion.apply(input_)
@@ -286,3 +332,7 @@ def reduce_from_polar_region(input_):
def scatter_to_polar_region(input_, dim_):
return _ScatterToPolarRegion.apply(input_, dim_)
def gather_from_polar_region(input_, dim_, shapes_):
return _GatherFromPolarRegion.apply(input_, dim_, shapes_)