OpenDAS / torch-harmonics · Commits

Commit 89fb38df (unverified)
Authored Aug 20, 2024 by Thorsten Kurth, committed by GitHub on Aug 20, 2024
Parent: 3a3480b8

adding reduce_scatter (#40)

Changes: 4 changed files with 115 additions and 38 deletions (+115 -38)
tests/test_distributed_convolution.py                    +3    -0
torch_harmonics/distributed/__init__.py                  +3    -1
torch_harmonics/distributed/distributed_convolution.py   +8    -14
torch_harmonics/distributed/primitives.py                +101  -23
tests/test_distributed_convolution.py
@@ -130,6 +130,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         lon_shapes = convolution_dist.lon_out_shapes

         # gather in W
+        tensor = tensor.contiguous()
         if self.grid_size_w > 1:
             gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
             olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
@@ -140,6 +141,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         tensor_gather = tensor

         # gather in H
+        tensor_gather = tensor_gather.contiguous()
         if self.grid_size_h > 1:
             gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
             olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
@@ -268,6 +270,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
+
             err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
             if self.world_rank == 0:
                 print(f"final relative error of gradients: {err.item()}")
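(The added .contiguous() calls matter because torch.distributed collectives expect contiguous send buffers; a transposed or sliced view can otherwise trip up the gather. A minimal single-process sketch of the layout issue the test guards against; the tensor shapes are purely illustrative:)

import torch

t = torch.randn(4, 6).t()     # a transposed view: same storage, non-contiguous layout
print(t.is_contiguous())      # False

# what the test now does before handing the tensor to the gather collective
t = t.contiguous()
print(t.is_contiguous())      # True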
torch_harmonics/distributed/__init__.py
@@ -39,7 +39,9 @@ from .primitives import (
     reduce_from_polar_region,
     scatter_to_polar_region,
     gather_from_polar_region,
-    copy_to_polar_region
+    copy_to_polar_region,
+    reduce_from_scatter_to_polar_region,
+    gather_from_copy_to_polar_region
 )

 # import the sht
torch_harmonics/distributed/distributed_convolution.py
@@ -54,7 +54,7 @@ from torch_harmonics.convolution import (
 from torch_harmonics.distributed import polar_group_size, azimuth_group_size
 from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
-from torch_harmonics.distributed import copy_to_polar_region, reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region
+from torch_harmonics.distributed import reduce_from_scatter_to_polar_region, gather_from_copy_to_polar_region
 from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
 from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
@@ -219,7 +219,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_out - 1)
+            theta_cutoff = torch.pi / float(self.nlat_out - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
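(For scale: with a 91-point latitude output grid and kernel_shape[0] = 3, the removed default evaluates to roughly twice the new one. A quick check of the two formulas; the grid size is illustrative, not taken from the commit:)

import math

nlat_out, k0 = 91, 3
old = (k0 + 1) / 2 * math.pi / (nlat_out - 1)   # ~0.0698 rad, about 4.0 degrees
new = math.pi / (nlat_out - 1)                  # ~0.0349 rad, about 2.0 degrees
print(math.degrees(old), math.degrees(new))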
@@ -268,7 +268,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         # store number of channels
         num_chans = x.shape[1]

         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
@@ -288,11 +288,8 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
             x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

-        # allreduce over latitudes: h is still local
-        x = reduce_from_polar_region(x)
-
-        # split tensor along latitudes: h is split
-        x = scatter_to_polar_region(x, -2)
+        # perform reduce scatter in polar region
+        x = reduce_from_scatter_to_polar_region(x, -2)

         # now we can transpose back the result, so that lon is split and channels are local
         if self.comm_size_azimuth > 1:
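(For intuition, the fused primitive is meant to reproduce the old reduce-then-scatter pair in one collective: a reduce-scatter over the polar group sums the per-rank partial results and hands each rank only its latitude slab. A minimal single-process sketch of that equivalence, with plain tensors standing in for per-rank shards; everything below is illustrative, not library API:)

import torch

torch.manual_seed(0)
ranks = 3
partials = [torch.randn(2, 4, 6, 8) for _ in range(ranks)]   # each "rank" holds a full partial result

# old path: all-reduce (sum over ranks), then scatter along the latitude dim (-2)
reduced = torch.stack(partials).sum(dim=0)
scattered = torch.chunk(reduced, ranks, dim=-2)              # rank r would keep scattered[r]

# fused path: reduce-scatter computes each rank's chunk of the sum directly
fused = [sum(torch.chunk(p, ranks, dim=-2)[r] for p in partials) for r in range(ranks)]

for r in range(ranks):
    assert torch.allclose(scattered[r], fused[r])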
@@ -352,7 +349,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = torch.pi / float(self.nlat_in - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -429,11 +426,8 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         # multiply weights
         x = self.quad_weights * x

-        # we need to gather the input tensor
-        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
-
-        # register allreduce for bwd pass
-        x = copy_to_polar_region(x)
+        # gather input tensor and set up backward reduction hooks
+        x = gather_from_copy_to_polar_region(x, -2, self.lat_in_shapes)

         if x.is_cuda and _cuda_extension_available:
             out = _disco_s2_transpose_contraction_cuda(
torch_harmonics/distributed/primitives.py
@@ -56,14 +56,6 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
     return sections

-
-# general helpers
-def get_memory_format(tensor):
-    if tensor.is_contiguous(memory_format=torch.channels_last):
-        return torch.channels_last
-    else:
-        return torch.contiguous_format
-
 def split_tensor_along_dim(tensor, dim, num_chunks):
     assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
@@ -78,23 +70,20 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
 def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
-    # get input format
-    input_format = get_memory_format(tensor)
-
     # get comm params
     comm_size = dist.get_world_size(group=group)
     comm_rank = dist.get_rank(group=group)

     # split and local transposition
     tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
-    x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
+    x_send = [y.contiguous() for y in tsplit]
     x_send_shapes = [x.shape for x in x_send]
     x_recv = []
     x_shape = list(x_send_shapes[comm_rank])
     for dim1_len in dim1_split_sizes:
         x_shape[dim1] = dim1_len
-        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))
+        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))

     # global transposition
     req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
@@ -108,24 +97,24 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dims, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
-        x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dims[1]).contiguous()
         ctx.dims = dims
         ctx.dim0_split_sizes = dim0_split_sizes
         return x

     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
-        gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dims[0]).contiguous()
         return gi, None, None
@@ -133,24 +122,22 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dim, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
-        x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dim[1]).contiguous()
         ctx.dim = dim
         ctx.dim0_split_sizes = dim0_split_sizes
         return x

     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
-        gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dim[0]).contiguous()
         return gi, None, None
@@ -175,7 +162,7 @@ def _reduce(input_, use_fp32=True, group=None):
     dist.all_reduce(input_, group=group)

     return input_


 def _split(input_, dim_, group=None):
     """Split the tensor along its last dimension and keep the corresponding slice."""
@@ -232,6 +219,33 @@ def _gather(input_, dim_, shapes_, group=None):
     return output


+def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
+    """All-reduce the input tensor across model parallel group and scatter it back."""
+
+    # Bypass the function if we are using only 1 GPU.
+    if dist.get_world_size(group=group) == 1:
+        return input_
+
+    # make input contiguous
+    comm_size = dist.get_world_size(group=group)
+    comm_rank = dist.get_rank(group=group)
+    input_list = [x.contiguous() for x in split_tensor_along_dim(input_, dim_, comm_size)]
+
+    dtype = input_.dtype
+    if (use_fp32 and (dtype != torch.float32)):
+        input_list = [x.to(torch.float32) for x in input_list]
+
+    # perform reduce_scatter
+    output = torch.empty_like(input_list[comm_rank])
+    dist.reduce_scatter(output, input_list, group=group)
+
+    # convert dtype if necessary
+    if use_fp32:
+        output = output.to(dtype=dtype)
+
+    return output
+
+
 class _CopyToPolarRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chunk to the rank."""
@@ -322,6 +336,62 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
         return grad_output


+class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
+    """All-reduce the input from the polar region and scatter back to polar region."""
+
+    @staticmethod
+    def symbolic(graph, input_, dim_):
+        if is_distributed_polar():
+            return _reduce_scatter(input_, dim_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def forward(ctx, input_, dim_):
+        if is_distributed_polar():
+            ctx.dim = dim_
+            ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
+            return _reduce_scatter(input_, dim_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
+        else:
+            return grad_output, None
+
+
+class _GatherFromCopyToPolarRegion(torch.autograd.Function):
+    """Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter"""
+
+    @staticmethod
+    def symbolic(graph, input_, dim_, shapes_):
+        if is_distributed_polar():
+            return _gather(input_, dim_, shapes_, polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def forward(ctx, input_, dim_, shapes_):
+        if is_distributed_polar():
+            ctx.dim = dim_
+            return _gather(input_, dim_, shapes_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
+        else:
+            return grad_output, None, None
+
+
 def copy_to_polar_region(input_):
     return _CopyToPolarRegion.apply(input_)
@@ -336,3 +406,11 @@ def scatter_to_polar_region(input_, dim_):
 def gather_from_polar_region(input_, dim_, shapes_):
     return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
+
+
+def reduce_from_scatter_to_polar_region(input_, dim_):
+    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)
+
+
+def gather_from_copy_to_polar_region(input_, dim_, shapes_):
+    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
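(Taken together, the two new wrappers are adjoints of each other: reduce_from_scatter_to_polar_region sums partial results over the polar group and leaves each rank with its latitude slab, with a gather registered for the backward pass, while gather_from_copy_to_polar_region reassembles the full latitude extent and registers the matching reduce-scatter for the backward pass. A hedged call-level sketch, assuming the torch_harmonics.distributed process groups have already been initialized on every rank of a multi-GPU launch; shapes and variable names are illustrative, not from the commit:)

import torch
import torch_harmonics.distributed as thd

# assumes thd's polar/azimuth process groups were set up beforehand (e.g. via its init routine)
nlat_full = 64
lat_shapes = thd.compute_split_shapes(nlat_full, thd.polar_group_size())

# a full-latitude partial result held by this rank
x = torch.randn(2, 8, nlat_full, 128, device="cuda")

# forward of the distributed conv: sum over the polar group, keep this rank's latitude slab
x_local = thd.reduce_from_scatter_to_polar_region(x, -2)

# forward of the transpose conv: gather the slabs back to full latitude,
# with the reduce-scatter hooked into the backward pass
x_full = thd.gather_from_copy_to_polar_region(x_local, -2, lat_shapes)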