Unverified Commit 214fa40a authored by Thorsten Kurth, committed by GitHub

Tkurth/distributed disco (#30)



* initial implementation of distributed DISCO layer

* working distributed convolution

* working refactored serial conv transpose with torch kernel

* working distributed conv and transposed conv when using the python kernel

* working distributed convolution with torch kernel

* fixed triton kernel tests

* adding print statement to debug CI

* adjusting tolerances in local convolution unittest

---------
Co-authored-by: Boris Bonev <bbonev@nvidia.com>
parent 54502a17
......@@ -155,30 +155,32 @@ def _precompute_convolution_tensor_dense(
class TestDiscreteContinuousConvolution(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda")
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333)
else:
self.device = torch.device("cpu")
self.device = torch.device("cpu")
torch.manual_seed(333)
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5],
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (18, 36), ( 6, 12), [4 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "legendre-gauss", False, 5e-5],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5],
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 6, 12), (18, 36), [4 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "legendre-gauss", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "legendre-gauss", True, 5e-5],
]
)
def test_disco_convolution(
......@@ -228,36 +230,38 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
)
# create a copy of the weight
w_ref = conv.weight.detach().clone()
w_ref.requires_grad_(True)
w_ref = torch.empty_like(conv.weight)
with torch.no_grad():
w_ref.copy_(conv.weight)
w_ref.requires_grad = True
# create an input signal
torch.manual_seed(333)
x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)
x = torch.randn(batch_size, in_channels, *in_shape, device=self.device)
# FWD and BWD pass
x.requires_grad = True
y = conv(x)
grad_input = torch.randn_like(y)
y.backward(grad_input)
x_grad = x.grad.clone()
# perform the reference computation
x_ref = x.clone().detach()
x_ref.requires_grad_(True)
x_ref.requires_grad = True
if transpose:
y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
else:
y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
# use the convolution module
y = conv(x)
y_ref.backward(grad_input)
x_ref_grad = x_ref.grad.clone()
# compare results
self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))
# compute gradients and compare results
grad_input = torch.randn_like(y)
y_ref.backward(grad_input)
y.backward(grad_input)
# compare
self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
if __name__ == "__main__":
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import unittest
from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', '29501'))
master_address = os.getenv('MASTER_ADDR', 'localhost')
cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available():
if cls.world_rank == 0:
print("Running test on GPU")
local_rank = cls.world_rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)
torch.cuda.manual_seed(333)
proc_backend = 'nccl'
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device('cpu')
proc_backend = 'gloo'
torch.manual_seed(333)
dist.init_process_group(backend = proc_backend,
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w
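# illustrative note (values assumed, not part of the test): with GRID_H = GRID_W = 2 the
# mapping above places world ranks on the process grid as
#   rank 0 -> (hrank=0, wrank=0), rank 1 -> (0, 1), rank 2 -> (1, 0), rank 3 -> (1, 1)
# since wrank = rank % grid_size_w and hrank = rank // grid_size_w.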
# now set up the comm groups:
#set default
cls.w_group = None
cls.h_group = None
# do the init
wgroups = []
for w in range(0, cls.world_size, cls.grid_size_w):
start = w
end = w + cls.grid_size_w
wgroups.append(list(range(start, end)))
if cls.world_rank == 0:
print("w-groups:", wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
if cls.world_rank == 0:
print("h-groups:", hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.h_group = tmp_group
if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
# initializing sht
thd.init(cls.h_group, cls.w_group)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
tensor_local = tensor_list_local[self.wrank]
# split in H
tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
tensor_local = tensor_list_local[self.hrank]
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
#print("tensor before gather shape", tensor.shape)
# gather in W
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
#print("tensor_gather shape", tensor_gather.shape)
# gather in H
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes
# gather in W
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, convolution_dist.nlon_in) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
@parameterized.expand([
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 64, 128, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
[ 64, 128, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
])
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
kernel_shape, groups, grid_in, grid_out, transpose, tol):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict(in_channels=C, out_channels=C,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
kernel_shape=kernel_shape, groups=groups,
grid_in=grid_in, grid_out=grid_out, bias=True)
# set up handles
if transpose:
conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
else:
conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)
# copy the weights from the local conv into the dist conv
with torch.no_grad():
conv_dist.weight.copy_(conv_local.weight)
conv_dist.bias.copy_(conv_local.bias)
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = conv_local(inp_full, use_triton_kernel=True)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = conv_dist(inp_local, use_triton_kernel=True)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = conv_dist(inp_local, use_triton_kernel=True)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
if __name__ == '__main__':
unittest.main()
......@@ -264,7 +264,7 @@ def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in:
# make sure that the grid-points of the output grid fall onto the grid points of the input grid
assert nlon_in % nlon_out == 0
pscale = nlon_in // nlon_out
# to simplify things, we merge batch and channel dimensions
grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out)
......@@ -409,30 +409,21 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
psi = psi.to(x.device)
batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
kernel_size, _, n_out = psi.shape
kernel_size, nlat_out, n_out = psi.shape
assert psi.shape[-2] == nlat_in
assert n_out % nlon_out == 0
nlat_out = n_out // nlon_out
assert nlon_out >= nlat_in
pscale = nlon_out // nlon_in
# we do a semi-transposition to facilitate the computation
inz = psi.indices()
tout = inz[2] // nlon_out
pout = inz[2] % nlon_out
# flip the axis of longitudes
pout = nlon_out - 1 - pout
tin = inz[1]
inz = torch.stack([inz[0], tout, tin*nlon_out + pout], dim=0)
psi_mod = torch.sparse_coo_tensor(inz, psi.values(), size=(kernel_size, nlat_out, nlat_in*nlon_out))
# interleave zeros along the longitude dimension to allow for fractional offsets to be considered
x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
# we need to go backwards through the vector, so we flip the axis
x_ext = x_ext.contiguous()
x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
# x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans
# we only need to apply the nlon stride here, since the nlat stride is taken care of by the kernel
x_ext[:, :, ::pscale, :] = x[...]
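# illustrative example (assumed sizes, not from the original code): nlon_in = 4 and
# nlon_out = 8 give pscale = 2, so the four input longitudes land in the even slots
# 0, 2, 4, 6 of x_ext while the odd slots stay zero; the roll below then realizes the
# fractional longitude offsets.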
# create output tensor
y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
for pout in range(nlon_out):
......@@ -440,10 +431,10 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
# TODO: double-check why this has to happen first
x_ext = torch.roll(x_ext, -1, dims=2)
# sparse contraction with the modified psi
y[:, pout, :, :] = torch.bmm(psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
# sum over the kernel dimension and reshape to the correct output size
y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
return y
......@@ -143,14 +143,12 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# array for accumulating non-zero indices
out_idx = torch.empty([3, 0], dtype=torch.long)
out_vals = torch.empty([0], dtype=torch.long)
# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
out_idx = []
out_vals = []
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = -lats_out[t]
......@@ -182,8 +180,12 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
# append indices and values to the COO datastructure
out_idx = torch.cat([out_idx, idx], dim=-1)
out_vals = torch.cat([out_vals, vals], dim=-1)
out_idx.append(idx)
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
return out_idx, out_vals
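# A minimal sketch (hypothetical call, shapes as used above) of how the returned COO
# indices and values are typically assembled into the sparse convolution tensor,
# mirroring get_psi() further below:
#   out_idx, out_vals = _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape)
#   psi = torch.sparse_coo_tensor(out_idx, out_vals,
#                                 size=(kernel_size, nlat_out, nlat_in * nlon_in)).coalesce()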
......@@ -328,7 +330,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
x = self.quad_weights * x
psi = self.get_psi()
if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
else:
......@@ -339,7 +341,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
x = x.reshape(B, self.groups, self.groupsize, K, H, W)
# do weight multiplication
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
if self.bias is not None:
......@@ -391,8 +393,18 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
def get_psi(self, use_triton_kernel=True):
if not use_triton_kernel:
# we do a semi-transposition to facilitate the computation
tout = self.psi_idx[2] // self.nlon_out
pout = self.psi_idx[2] % self.nlon_out
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin*self.nlon_out + pout], dim=0)
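# worked example (assumed value nlon_out = 8, for illustration only): a flat column index
# of 19 decomposes into tout = 19 // 8 = 2 and a flipped pout = 8 - 1 - (19 % 8) = 4,
# so the corresponding nonzero moves to column tin * 8 + 4 of the semi-transposed psi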
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
......@@ -401,13 +413,13 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
x = x.reshape(B, self.groups, self.groupsize, H, W)
# do weight multiplication
x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
psi = self.get_psi()
psi = self.get_psi(x.is_cuda and use_triton_kernel)
if x.is_cuda and use_triton_kernel:
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
......
......@@ -32,8 +32,13 @@
# we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import compute_split_shapes, split_tensor_along_dim, distributed_transpose_azimuth, distributed_transpose_polar
from .primitives import compute_split_shapes, split_tensor_along_dim
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar, reduce_from_polar_region, scatter_to_polar_region
# import the sht stuff
# import the sht
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVectorSHT
# import DISCO
from .distributed_convolution import DistributedDiscreteContinuousConvS2
from .distributed_convolution import DistributedDiscreteContinuousConvTransposeS2
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
from typing import List, Tuple, Union, Optional
import math
import torch
import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import (
_disco_s2_contraction_torch,
_disco_s2_transpose_contraction_torch,
_disco_s2_contraction_triton,
_disco_s2_transpose_contraction_triton,
)
from torch_harmonics.convolution import (
_compute_support_vals_isotropic,
_compute_support_vals_anisotropic,
_precompute_convolution_tensor_2d,
DiscreteContinuousConv,
)
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular",
theta_cutoff=0.01 * math.pi, distributed_mode="columns"):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
This is the distributed version: the matrix can be split either column- or row-wise. Column-wise splitting is preferable because the kernel performs many atomic additions
for the row reductions, which can then be combined into a single allreduce.
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# split the latitude vector:
comm_size_polar = polar_group_size()
comm_rank_polar = polar_group_rank()
if distributed_mode == "columns":
lats_in = split_tensor_along_dim(lats_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
elif distributed_mode == "rows":
lats_out = split_tensor_along_dim(lats_out, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
nlat_out = lats_out.shape[0]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
out_idx = []
out_vals = []
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = -lats_out[t]
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
# normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
y = y / norm
z = z / norm
# compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = kernel_handle(theta, phi)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
# append indices and values to the COO datastructure
out_idx.append(idx)
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)
return out_idx, out_vals
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
We assume the data can be split along the polar and azimuthal directions.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# we need those shapes:
self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)
# compute theta cutoff based on the bandlimit of the input field
if theta_cutoff is None:
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
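# for example (values assumed): kernel_shape = [3] and nlat_in = 128 give
# theta_cutoff = 4 * pi / 127, roughly 0.099 rad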
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / float(self.nlon_in)
# Note that the psi matrix is of shape nlat_out x (nlat_in * nlon_in). Since the contraction in the nlon direction is a convolution,
# we keep it local to all ranks and split the computation along nlat. We further split the input dimension because this reduces the number
# of atomic reduction calls inside the actual kernel
distributed_mode = "columns"
# set local shapes according to distributed mode:
if distributed_mode == "columns":
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
elif distributed_mode == "rows":
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out,
theta_cutoff=theta_cutoff, distributed_mode=distributed_mode)
# split the weight tensor as well
if distributed_mode == "columns":
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
self.register_buffer("quad_weights", quad_weights, persistent=False)
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce()
return psi
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
# store number of channels
num_chans = x.shape[1]
#print("input shape", x.shape)
# h and w is split. First we make w local by transposing into channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
#print("transposed shape", x.shape)
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
#print("multiplied shape", x.shape)
psi = self.get_psi()
#print("psi shape", psi.shape)
if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
else:
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
#print("psi * x shape", x.shape)
# allreduce over latitudes: h is still local
x = reduce_from_polar_region(x)
#print("reduced shape", x.shape)
# split tensor along latitudes: h is split
x = scatter_to_polar_region(x, -2)
#print("scattered shape", x.shape)
# now we can transpose back the result, so that lon is split and channels are local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
# extract shape
B, C, K, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, K, H, W)
# do weight multiplication
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
return out
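# A minimal usage sketch (hypothetical shapes; assumes thd.init(...) has already registered
# the polar and azimuth process groups, as done in the distributed unit tests):
#   conv = DistributedDiscreteContinuousConvS2(in_channels=32, out_channels=32,
#                                              in_shape=(128, 256), out_shape=(64, 128),
#                                              kernel_shape=[3]).to(device)
#   y_local = conv(x_local)  # x_local is this rank's (B, C, nlat_local, nlon_local) shard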
class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
def __init__(
self,
in_channels: int,
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# we need those shapes:
self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)
# bandlimit
if theta_cutoff is None:
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
# Note that the psi matrix is of shape nlat_out x (nlat_in * nlon_in). Since the contraction in the nlon direction is a convolution,
# we keep it local to all ranks and split the computation along nlat. We further split the input dimension because this reduces the number
# of atomic reduction calls inside the actual kernel
distributed_mode = "columns"
# set local shapes according to distributed mode:
if distributed_mode == "columns":
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
elif distributed_mode == "rows":
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
else:
raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx, vals = _precompute_distributed_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in,
theta_cutoff=theta_cutoff, distributed_mode="rows" if distributed_mode == "columns" else "columns")
## do partial transpose
## we do a semi-transposition to facilitate the computation
#tout = iidx[2] // self.nlon_out
#pout = iidx[2] % self.nlon_out
## flip the axis of longitudes
#pout = self.nlon_out - 1 - pout
#tin = iidx[1]
#idx = torch.stack([iidx[0], tout, tin*self.nlon_out + pout], dim=0)
# split the weight tensor as well
if distributed_mode == "columns":
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
# register all buffers
self.register_buffer("quad_weights", quad_weights, persistent=False)
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
def get_psi(self, use_triton_kernel=True):
if not use_triton_kernel:
# do partial transpose
# we do a semi-transposition to facilitate the computation
tout = self.psi_idx[2] // self.nlon_out
pout = self.psi_idx[2] % self.nlon_out
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin*self.nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
# do weight multiplication
x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
num_chans = x.shape[1]
# transpose such that lon is local, channels are split
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
if x.is_cuda and use_triton_kernel:
psi = self.get_psi(True)
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
else:
psi = self.get_psi(False)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
# allreduce over latitudes: h is still local
out = reduce_from_polar_region(out)
# split tensor along latitudes: h is split
out = scatter_to_polar_region(out, -2)
# now we can transpose back the result, so that lon is split and channels are local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
out = distributed_transpose_azimuth.apply(out, (-1, 1), chan_shapes)
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
return out
......@@ -33,7 +33,8 @@ from typing import List
import torch
import torch.distributed as dist
from .utils import polar_group, azimuth_group, is_initialized
from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
......@@ -152,3 +153,136 @@ class distributed_transpose_polar(torch.autograd.Function):
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
return gi, None, None
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# make input contiguous
input_ = input_.contiguous()
# All-reduce.
if use_fp32:
dtype = input_.dtype
inputf_ = input_.float()
dist.all_reduce(inputf_, group=group)
input_ = inputf_.to(dtype)
else:
dist.all_reduce(input_, group=group)
return input_
def _split(input_, dim_, group=None):
"""Split the tensor along its last dimension and keep the corresponding slice."""
# Bypass the function if we are using only 1 GPU.
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_dim(input_, dim_, comm_size)
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous()
return output
def _gather(input_, dim_, shapes_, group=None):
"""Gather unevenly split tensors across ranks"""
comm_size = dist.get_world_size(group=group)
if (shapes_ is not None) and (len(shapes_) != comm_size):
raise ValueError()
if dim_ >= input_.dim():
raise ValueError()
if comm_size == 1:
return input_
# make contiguous:
input_ = input_.contiguous()
input_shape = list(input_.shape)
if shapes_ is not None:
input_list = [None] * comm_size
for src in range(comm_size):
input_shape[dim_] = shapes_[src]
input_list[src] = torch.empty(
input_shape,
dtype=input_.dtype,
device=input_.device,
)
else:
# assume equal shape on all ranks
input_list = [torch.empty_like(input_) for _ in range(comm_size)]
dist.all_gather(input_list, input_, group=group)
output = torch.cat(input_list, dim=dim_).contiguous()
return output
class _ScatterToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@staticmethod
def symbolic(graph, input_, dim_):
return _split(input_, dim_, group=polar_group())
@staticmethod
def forward(ctx, input_, dim_):
if is_distributed_polar():
ctx.dim = dim_
ctx.split_shapes = compute_split_shapes(
input_.shape[dim_], polar_group_size()
)
return _split(input_, dim_, group=polar_group())
else:
return input_
@staticmethod
def backward(ctx, grad_output):
if is_distributed_polar():
return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
else:
return grad_output, None
class _ReduceFromPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region."""
@staticmethod
def symbolic(graph, input_):
if is_distributed_polar():
return _reduce(input_, group=polar_group())
else:
return input_
@staticmethod
def forward(ctx, input_):
if is_distributed_polar():
return _reduce(input_, group=polar_group())
else:
return input_
@staticmethod
def backward(ctx, grad_output):
return grad_output
def reduce_from_polar_region(input_):
return _ReduceFromPolarRegion.apply(input_)
def scatter_to_polar_region(input_, dim_):
return _ScatterToPolarRegion.apply(input_, dim_)
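# A minimal usage sketch (illustrative only): inside the distributed convolutions the
# partial contraction results are combined over latitudes roughly as
#   out = reduce_from_polar_region(out)     # allreduce the partial sums over the polar group
#   out = scatter_to_polar_region(out, -2)  # keep only this rank's latitude chunk
# so that each rank ends up with its local share of the output latitudes.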