Unverified commit 29e7fb68 authored by Boris Bonev, committed by GitHub

Tkurth/cuda disco (#38)



* adding cuda kernels for disco conv

* making psi_idx an attribute

* adding license headers

* adding author files

* reorganizing files

* draft implementation

* added conditional installation to setup.py

* formatting changes

* removing triton kernel in DISCO convolution

* updated github actions

* updated Readme and changelog

* adding another guard for the cuda installation

* renaming the cuda extension

* simplifying setup.py

* minor bugfix

* Bbonev/cuda disco cleanup (#32)

* cleanup of disco convolutions based on CUDA extension

* fixing unittest

* changing version to experimental 0.7.0a

* initial rewrite of the distributed convolution with CUDA

* fixing streams

* need to fix install options

* fixing streams

* undid setup.py changes

* reset setup.py

* including CUDAStream

* adjusted the precomputation of theta_cutoff. If you rely on this, your models will not be backwards-compatible.

* adjusting theta_cutoff in the unittest

* adding newly refactored kernels for faster compile

* Tkurth/cuda disco distributed fix (#34)

* attempt to make disco distributed

* working distributed convolutions

* fixing distributed conv

* working distributed disco

* removing irrelevant extra argument

* using stream functions from at instead of c10

* using stream functions from at instead of c10, small fix

* Bbonev/disc even filters (#35)

* initial working commit with new convention of counting collocation points across the diameter instead of across the radius

* fixed a bug in the computation of the even kernels

* changing heuristic for computing theta_cutoff

* Fixing unittest

* Readability improvements

* reworked normalization of filter basis functions

* implemented discrete normalization of disco filters

* relaxing tolerances in convolution unit test

* bugfix to correctly support unequal scale factors in latitudes and longitudes

* hotfix to a bug in the imports

* Bbonev/distributed disco refactor (#37)

* cleaned up normalization code in convolution

* formatting changes in distributed convolution

* Fixing default theta_cutoff to be the same in distributed and local case

* fixed distributed convolution to support the same normalization as non-distributed one

* readability improvements

* fixed initial scale of convolution parameter weights and fixed naming of the normalization routine

* Updated Readme.md

* added comment in Dockerfile regarding older architectures

---------
Co-authored-by: Thorsten Kurth <tkurth@nvidia.com>
Co-authored-by: Boris Bonev <bbonev@nvidia.com>
parent 214fa40a
@@ -31,6 +31,8 @@
 import abc
 from typing import List, Tuple, Union, Optional
+from itertools import accumulate
+from warnings import warn
 import math
@@ -40,28 +42,37 @@ import torch.nn as nn
 from functools import partial
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
-from torch_harmonics._disco_convolution import (
-    _disco_s2_contraction_torch,
-    _disco_s2_transpose_contraction_torch,
-    _disco_s2_contraction_triton,
-    _disco_s2_transpose_contraction_triton,
-)
+from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
+from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
 from torch_harmonics.convolution import (
     _compute_support_vals_isotropic,
     _compute_support_vals_anisotropic,
-    _precompute_convolution_tensor_2d,
+    _normalize_convolution_tensor_s2,
     DiscreteContinuousConv,
 )
 from torch_harmonics.distributed import polar_group_size, azimuth_group_size
 from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
-from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region
+from torch_harmonics.distributed import copy_to_polar_region, reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region
 from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
 from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
 
+# import custom C++/CUDA extensions
+from disco_helpers import preprocess_psi
+
+try:
+    import disco_cuda_extension
+    _cuda_extension_available = True
+except ImportError as err:
+    disco_cuda_extension = None
+    _cuda_extension_available = False
 
-def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular",
-                                                  theta_cutoff=0.01 * math.pi, distributed_mode="columns"):
+def _precompute_distributed_convolution_tensor_s2(
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
+):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
     Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -76,40 +87,26 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
             \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
         \end{bmatrix}}
     $$
-
-    This is the distributed version: the matrix can either be split column- or row-wise. Column-wise seems better because the kernel has a lot of summation
-    atomics concerning the row reductions, which we can combine in a single allreduce.
     """
 
     assert len(in_shape) == 2
     assert len(out_shape) == 2
 
     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
+        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
     elif len(kernel_shape) == 2:
-        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
+        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
 
     nlat_in, nlon_in = in_shape
     nlat_out, nlon_out = out_shape
 
-    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
+    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
     lats_in = torch.from_numpy(lats_in).float()
-    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
+    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
     lats_out = torch.from_numpy(lats_out).float()
 
-    # split the latitude vector:
-    comm_size_polar = polar_group_size()
-    comm_rank_polar = polar_group_rank()
-    if distributed_mode == "columns":
-        lats_in = split_tensor_along_dim(lats_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
-    elif distributed_mode == "rows":
-        lats_out = split_tensor_along_dim(lats_out, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
-        nlat_out = lats_out.shape[0]
-    else:
-        raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
-
     # compute the phi differences
     # it is important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
     lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
@@ -151,11 +148,36 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
         out_vals.append(vals)
 
     # concatenate the indices and values
-    out_idx = torch.cat(out_idx, dim=-1)
-    out_vals = torch.cat(out_vals, dim=-1)
+    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
+    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
+
+    # perform the normalization over the entire psi matrix
+    if transpose_normalization:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
+    else:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
+    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
+
+    # TODO: this part can be split off into its own function
+    # split the latitude indices:
+    comm_size_polar = polar_group_size()
+    comm_rank_polar = polar_group_rank()
+    split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
+    offsets = [0] + list(accumulate(split_shapes))
+    start_idx = offsets[comm_rank_polar]
+    end_idx = offsets[comm_rank_polar + 1]
+
+    # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
+    lats = out_idx[2] // nlon_in
+    lons = out_idx[2] % nlon_in
+    ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
+    out_vals = out_vals[ilats]
+
+    # for the indices we need to recompute them to refer to local indices of the input tensor
+    out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
 
     return out_idx, out_vals
 
 
 class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
     """
     Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
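The hunk above keeps only the psi entries whose input latitude falls into the rank's local chunk and rebases the flat column index so it addresses the local input tensor. A minimal standalone illustration of that remapping (sizes and index values below are invented for the example):

```python
import torch

nlon_in = 4                # toy number of longitude points per latitude ring
start_idx, end_idx = 2, 4  # local latitude range [start_idx, end_idx) owned by this rank

# flat global column indices, encoded as col = lat * nlon_in + lon
col = torch.tensor([1, 9, 11, 14])
lats = col // nlon_in      # -> tensor([0, 2, 2, 3])
lons = col % nlon_in       # -> tensor([1, 1, 3, 2])

# keep only entries whose input latitude lives on this rank
keep = (lats >= start_idx) & (lats < end_idx)

# rebase to local flat indices: (lat - start_idx) * nlon_in + lon
local_col = (lats[keep] - start_idx) * nlon_in + lons[keep]
print(local_col)           # tensor([1, 3, 6])
```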
@@ -197,7 +219,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
 
         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_out - 1)
 
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
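With the new convention of counting collocation points across the filter diameter (see the commit notes above), the default cutoff in this hunk changes as follows, writing K = kernel_shape[0]:

$$
\theta_{\mathrm{cutoff}}^{\mathrm{old}} = \frac{(K+1)\,\pi}{n_{\mathrm{lat,in}}-1},
\qquad
\theta_{\mathrm{cutoff}}^{\mathrm{new}} = \frac{(K+1)\,\pi}{2\,(n_{\mathrm{lat,out}}-1)}.
$$

For example, with K = 3 and nlat_in = nlat_out = 181, the default drops from 4π/180 ≈ 0.070 rad to 2π/180 ≈ 0.035 rad, which is why the commit notes warn that models relying on the old default are not backwards-compatible.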
@@ -209,70 +231,69 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
 
         # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in the nlon direction is a convolution,
         # we will keep it local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
         # of atomic reduction calls inside the actual kernel
-        distributed_mode = "columns"
-
         # set local shapes according to distributed mode:
-        if distributed_mode == "columns":
-            self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
-            self.nlat_out_local = self.nlat_out
-        elif distributed_mode == "rows":
-            self.nlat_in_local = self.nlat_in
-            self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
-        else:
-            raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
-
-        idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out,
-                                                                  theta_cutoff=theta_cutoff, distributed_mode=distributed_mode)
+        self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
+        self.nlat_out_local = self.nlat_out
+
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
+        )
 
         # split the weight tensor as well
-        if distributed_mode == "columns":
-            quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
+        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
         self.register_buffer("quad_weights", quad_weights, persistent=False)
-        self.register_buffer("psi_idx", idx, persistent=False)
+
+        # sort the values
+        ker_idx = idx[0, ...].contiguous()
+        row_idx = idx[1, ...].contiguous()
+        col_idx = idx[2, ...].contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals)
+
+        # preprocessed data-structure for GPU kernel
+        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
+        self.register_buffer("psi_row_idx", row_idx, persistent=False)
+        self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)
 
+    @property
+    def psi_idx(self):
+        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
+
     def get_psi(self):
         psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce()
         return psi
 
-    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # store number of channels
         num_chans = x.shape[1]
-        #print("input shape", x.shape)
 
         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
-        #print("transposed shape", x.shape)
 
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x
-        #print("multiplied shape", x.shape)
 
-        psi = self.get_psi()
-        #print("psi shape", psi.shape)
-
-        if x.is_cuda and use_triton_kernel:
-            x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
-        else:
-            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
-        #print("psi * x shape", x.shape)
+        if x.is_cuda and _cuda_extension_available:
+            x = _disco_s2_contraction_cuda(
+                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
+            )
+        else:
+            if x.is_cuda:
+                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
+            psi = self.get_psi()
+            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
 
         # allreduce over latitudes: h is still local
         x = reduce_from_polar_region(x)
-        #print("reduced shape", x.shape)
 
         # split tensor along latitudes: h is split
         x = scatter_to_polar_region(x, -2)
-        #print("scattered shape", x.shape)
 
         # now we can transpose back the result, so that lon is split and channels are local
         if self.comm_size_azimuth > 1:
             chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
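The preprocess_psi routine used above comes from the compiled disco_helpers extension and is not shown in this diff. Judging from the buffer names and comments, it sorts the COO triples and returns a row-offset array so the CUDA kernel can walk each (kernel, output-latitude) row contiguously, similar in spirit to a CSR layout. A rough pure-PyTorch sketch of that idea follows; it is an illustration only, not the extension code, and the helper name is made up:

```python
import torch

def build_row_offsets(nker: int, nlat_out: int, ker_idx, row_idx, col_idx, vals):
    """Illustration only: sort COO entries by (kernel, row, col) and build a
    CSR-style offset array with one entry per (kernel, row) pair plus a sentinel."""
    ncols = int(col_idx.max()) + 1
    # lexicographic sort key over (kernel, row, col)
    key = (ker_idx * nlat_out + row_idx) * ncols + col_idx
    perm = torch.argsort(key)
    ker_idx, row_idx, col_idx, vals = ker_idx[perm], row_idx[perm], col_idx[perm], vals[perm]

    # offsets[k * nlat_out + r] : offsets[k * nlat_out + r + 1] brackets row r of kernel k
    counts = torch.bincount(ker_idx * nlat_out + row_idx, minlength=nker * nlat_out)
    offsets = torch.zeros(nker * nlat_out + 1, dtype=torch.long)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets, ker_idx, row_idx, col_idx, vals
```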
@@ -284,7 +305,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
 
         # do weight multiplication
         out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
-        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
+        out = out.reshape(out.shape[0], -1, H, W)
 
         if self.bias is not None:
             out = out + self.bias.reshape(1, -1, 1, 1)
@@ -331,7 +352,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
 
         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_in - 1)
 
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -343,43 +364,40 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
 
         # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in the nlon direction is a convolution,
         # we will keep it local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
         # of atomic reduction calls inside the actual kernel
-        distributed_mode = "columns"
-
         # set local shapes according to distributed mode:
-        if distributed_mode == "columns":
-            self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
-            self.nlat_out_local = self.nlat_out
-        elif distributed_mode == "rows":
-            self.nlat_in_local = self.nlat_in
-            self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
-        else:
-            raise NotImplementedError(f"Error, unknown distributed mode {distributed_mode}.")
+        self.nlat_in_local = self.nlat_in
+        self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
 
         # switch in_shape and out_shape since we want transpose conv
         # distributed mode here is swapped because of the transpose
-        idx, vals = _precompute_distributed_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in,
-                                                                  theta_cutoff=theta_cutoff, distributed_mode="rows" if distributed_mode else "columns")
-
-        ## do partial transpose
-        ## we do a semi-transposition to facilitate the computation
-        #tout = iidx[2] // self.nlon_out
-        #pout = iidx[2] % self.nlon_out
-        ## flip the axis of longitudes
-        #pout = self.nlon_out - 1 - pout
-        #tin = iidx[1]
-        #idx = torch.stack([iidx[0], tout, tin*self.nlon_out + pout], dim=0)
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
+        )
 
         # split the weight tensor as well
-        if distributed_mode == "columns":
-            quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
+        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
 
+        # register all buffers
         self.register_buffer("quad_weights", quad_weights, persistent=False)
-        self.register_buffer("psi_idx", idx, persistent=False)
+
+        # sort the values
+        ker_idx = idx[0, ...].contiguous()
+        row_idx = idx[1, ...].contiguous()
+        col_idx = idx[2, ...].contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals)
+
+        # preprocessed data-structure for GPU kernel
+        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
+        self.register_buffer("psi_row_idx", row_idx, persistent=False)
+        self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)
 
-    def get_psi(self, use_triton_kernel=True):
-        if not use_triton_kernel:
+    @property
+    def psi_idx(self):
+        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
+
+    def get_psi(self, semi_transposed: bool = False):
+        if semi_transposed:
             # do partial transpose
             # we do a semi-transposition to facilitate the computation
             tout = self.psi_idx[2] // self.nlon_out
@@ -387,41 +405,45 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
 
             # flip the axis of longitudes
             pout = self.nlon_out - 1 - pout
             tin = self.psi_idx[1]
-            idx = torch.stack([self.psi_idx[0], tout, tin*self.nlon_out + pout], dim=0)
+            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
             psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce()
         else:
             psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce()
 
         return psi
 
-    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # extract shape
         B, C, H, W = x.shape
         x = x.reshape(B, self.groups, self.groupsize, H, W)
 
         # do weight multiplication
         x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
-        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
+        x = x.reshape(B, -1, x.shape[-3], H, W)
         num_chans = x.shape[1]
 
         # transpose such that lon is local, channels are split
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
 
-        # pre-multiply x with the quadrature weights
+        # multiply weights
         x = self.quad_weights * x
 
-        if x.is_cuda and use_triton_kernel:
-            psi = self.get_psi(True)
-            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
-        else:
-            psi = self.get_psi(False)
-            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
-
-        # allreduce over latitudes: h is still local
-        out = reduce_from_polar_region(out)
-
-        # split tensor along latitudes: h is split
-        out = scatter_to_polar_region(out, -2)
+        # we need to gather the input tensor
+        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
+
+        # register allreduce for bwd pass
+        x = copy_to_polar_region(x)
+
+        if x.is_cuda and _cuda_extension_available:
+            out = _disco_s2_transpose_contraction_cuda(
+                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
+            )
+        else:
+            if x.is_cuda:
+                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
+            psi = self.get_psi(semi_transposed=True)
+            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
 
         # now we can transpose back the result, so that lon is split and channels are local
         if self.comm_size_azimuth > 1:
@@ -432,4 +454,3 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
 
             out = out + self.bias.reshape(1, -1, 1, 1)
 
         return out
@@ -232,6 +232,25 @@ def _gather(input_, dim_, shapes_, group=None):
 
     return output
 
 
+class _CopyToPolarRegion(torch.autograd.Function):
+    """Pass the input through unchanged; the gradient is all-reduced over the polar group in the backward pass."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    def forward(ctx, input_):
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _reduce(grad_output, group=polar_group())
+        else:
+            return grad_output
+
+
 class _ScatterToPolarRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chunk to the rank."""
@@ -258,6 +277,29 @@ class _ScatterToPolarRegion(torch.autograd.Function):
 
         return grad_output, None
 
 
+class _GatherFromPolarRegion(torch.autograd.Function):
+    """Gather the input from all ranks in the polar group and keep it on the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_, dim_, shapes_):
+        return _gather(input_, dim_, shapes_, polar_group())
+
+    @staticmethod
+    def forward(ctx, input_, dim_, shapes_):
+        if is_distributed_polar():
+            ctx.dim = dim_
+            return _gather(input_, dim_, shapes_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _split(grad_output, ctx.dim, group=polar_group()), None, None
+        else:
+            return grad_output, None, None
+
+
 class _ReduceFromPolarRegion(torch.autograd.Function):
     """All-reduce the input from the polar region."""
@@ -280,9 +322,17 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
 
         return grad_output
 
 
+def copy_to_polar_region(input_):
+    return _CopyToPolarRegion.apply(input_)
+
+
 def reduce_from_polar_region(input_):
     return _ReduceFromPolarRegion.apply(input_)
 
 
 def scatter_to_polar_region(input_, dim_):
     return _ScatterToPolarRegion.apply(input_, dim_)
+
+
+def gather_from_polar_region(input_, dim_, shapes_):
+    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)