Commit 50ebe96f authored by apaaris, committed by Boris Bonev

Improved docstrings in distributed

parent 10cf65f6
@@ -74,6 +74,57 @@ def _split_distributed_convolution_tensor_s2(
in_shape: Tuple[int],
out_shape: Tuple[int],
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which, applied to the north pole $(0,0,1)^T$, yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
Parameters
----------
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
filter_basis: FilterBasis
Filter basis to use
grid_in: str
Grid type for the input tensor
grid_out: str
Grid type for the output tensor
theta_cutoff: float
Theta cutoff for the filter basis
theta_eps: float
Epsilon for the theta cutoff
transpose_normalization: bool
Whether to transpose the normalization
basis_norm_mode: str
Normalization mode for the filter basis
merge_quadrature: bool
Whether to merge the quadrature weights
Returns
-------
out_idx: torch.Tensor
Indices of the non-zero entries of the precomputed (sparse) filter tensor
out_vals: torch.Tensor
Values of the non-zero entries of the precomputed (sparse) filter tensor
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
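As a quick sanity check on the YZY formula above (an illustrative snippet, not part of the diff), the closed-form vector can be verified numerically against explicit rotation matrices:

```python
import math
import torch

def roty(a):
    # rotation about the y-axis
    c, s = math.cos(a), math.sin(a)
    return torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def rotz(a):
    # rotation about the z-axis
    c, s = math.cos(a), math.sin(a)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

alpha, beta, gamma = 0.3, 1.1, -0.7
n = torch.tensor([0.0, 0.0, 1.0])  # the north pole
rotated = roty(alpha) @ rotz(beta) @ roty(gamma) @ n
expected = torch.tensor([
    math.cos(gamma) * math.sin(alpha) + math.cos(alpha) * math.cos(beta) * math.sin(gamma),
    math.sin(beta) * math.sin(gamma),
    math.cos(alpha) * math.cos(gamma) - math.cos(beta) * math.sin(alpha) * math.sin(gamma),
])
assert torch.allclose(rotated, expected)
```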
@@ -102,10 +153,43 @@ def _split_distributed_convolution_tensor_s2(
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
We assume the data can be split in the polar and azimuthal directions.
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
def __init__(
@@ -247,6 +331,40 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Distributed version of discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
We assume the data can be split in the polar and azimuthal directions.
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
@@ -48,6 +48,30 @@ class DistributedRealSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -56,10 +80,22 @@ class DistributedRealSHT(nn.Module):
"""
Distributed SHT layer. Expects the last three dimensions of the input tensor to be channels, latitude, and longitude.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super().__init__()
@@ -168,9 +204,31 @@ class DistributedInverseRealSHT(nn.Module):
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., nlat, nlon)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -282,6 +340,30 @@ class DistributedRealVectorSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., 2, lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -290,10 +372,18 @@ class DistributedRealVectorSHT(nn.Module):
"""
Initializes the vector SHT layer, precomputing the necessary quadrature weights.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super().__init__()
@@ -425,6 +515,30 @@ class DistributedInverseRealVectorSHT(nn.Module):
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., 2, nlat, nlon)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -39,6 +39,19 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
"""
Compute the split shapes for a given size and number of chunks.
Parameters
----------
size: int
The size of the tensor dimension to split
num_chunks: int
The number of chunks to split into
Returns
-------
List[int]
The sizes of the chunks, which differ by at most one
"""
# treat trivial case first
if num_chunks == 1:
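The hunk is truncated here, but the balanced split the docstring promises is easy to sketch (a minimal re-implementation for illustration, assuming the remainder goes to the leading chunks):

```python
def balanced_split(size: int, num_chunks: int):
    # spread the remainder over the first (size % num_chunks) chunks,
    # so that chunk sizes differ by at most one
    base, rem = divmod(size, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

print(balanced_split(10, 3))  # -> [4, 3, 3]
```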
@@ -59,6 +72,24 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
def split_tensor_along_dim(tensor, dim, num_chunks):
"""
Split a tensor along a given dimension into a given number of chunks.
Parameters
----------
tensor: torch.Tensor
The tensor to split
dim: int
The dimension to split along
num_chunks: int
The number of chunks to split into
Returns
-------
tensor_list: List[torch.Tensor]
The split tensors
"""
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
{num_chunks} chunks. Empty slices are currently not supported."
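A short usage sketch (the expected chunk widths follow from the balanced split computed by compute_split_shapes):

```python
import torch

t = torch.randn(4, 10)
chunks = split_tensor_along_dim(t, dim=1, num_chunks=3)
print([c.shape[1] for c in chunks])  # expected: [4, 3, 3]
```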
@@ -71,6 +102,26 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
"""
Distributed transpose: redistribute a tensor between two dimensions using an all-to-all communication.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
group: ProcessGroup, optional
The process group used for the communication
async_op: bool, optional
Whether to perform the all-to-all asynchronously
Returns
-------
tensor_list: List[torch.Tensor]
The received chunks, to be concatenated along dim1
dim0_split_sizes: List[int]
The split sizes for the first dimension
req: dist.Request
The communication request handle (used when async_op is set)
"""
# get comm params
comm_size = dist.get_world_size(group=group)
comm_rank = dist.get_rank(group=group)
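Conceptually, the distributed transpose splits the locally held tensor along dim0, exchanges the chunks with an all-to-all, and leaves each rank with its slice of dim0 and, after concatenation, the full dim1. A minimal sketch of that pattern, assuming the actual implementation follows it (the async path and error handling are omitted):

```python
import torch
import torch.distributed as dist

def transpose_sketch(tensor, dim0, dim1, dim1_split_sizes, group=None):
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    # rank r currently holds the full dim0 and a dim1 slice of size dim1_split_sizes[r]
    dim0_split_sizes = compute_split_shapes(tensor.shape[dim0], comm_size)
    send = [c.contiguous() for c in torch.split(tensor, dim0_split_sizes, dim=dim0)]
    recv = []
    for j in range(comm_size):
        shape = list(tensor.shape)
        shape[dim0] = dim0_split_sizes[comm_rank]
        shape[dim1] = dim1_split_sizes[j]
        recv.append(torch.empty(shape, dtype=tensor.dtype, device=tensor.device))
    dist.all_to_all(recv, send, group=group)
    # concatenating recv along dim1 restores the full dim1 locally
    return recv, dim0_split_sizes
```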
@@ -99,6 +150,26 @@ class distributed_transpose_azimuth(torch.autograd.Function):
Distributed transpose operation for azimuthal dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the azimuthal dimension.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The two dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
@staticmethod
@@ -107,10 +178,19 @@ class distributed_transpose_azimuth(torch.autograd.Function):
r"""
Forward pass for distributed azimuthal transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
@@ -126,11 +206,15 @@ class distributed_transpose_azimuth(torch.autograd.Function):
r"""
Backward pass for distributed azimuthal transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
"""
dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes
@@ -146,6 +230,24 @@ class distributed_transpose_polar(torch.autograd.Function):
Distributed transpose operation for polar dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the polar dimension.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
@staticmethod
@@ -154,10 +256,19 @@ class distributed_transpose_polar(torch.autograd.Function):
r"""
Forward pass for distributed polar transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dim: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
@@ -172,11 +283,15 @@ class distributed_transpose_polar(torch.autograd.Function):
r"""
Backward pass for distributed polar transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
"""
dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes
@@ -292,6 +407,16 @@ class _CopyToPolarRegion(torch.autograd.Function):
Copy tensor to polar region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
output: torch.Tensor
The input tensor, unchanged (gradients are all-reduced over the polar group in the backward pass)
"""
@staticmethod
@@ -304,11 +429,15 @@ class _CopyToPolarRegion(torch.autograd.Function):
r"""
Forward pass for copying to polar region.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
input_: torch.Tensor
The input tensor, unchanged (the forward pass is a no-op)
"""
return input_
@@ -318,11 +447,15 @@ class _CopyToPolarRegion(torch.autograd.Function):
r"""
Backward pass for copying to polar region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_input: torch.Tensor
The gradient of the input (all-reduced over the polar group when distributed)
"""
if is_distributed_polar():
return _reduce(grad_output, group=polar_group())
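The copy primitives follow the standard tensor-parallel pattern: the forward pass is the identity, and the backward pass all-reduces the gradient over the group, which is the adjoint of making a tensor available on every rank. A generic sketch of the pattern (illustrative; the actual classes bind the polar or azimuth group and guard on the is_distributed_* helpers):

```python
import torch
import torch.distributed as dist

class CopyToGroup(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.contiguous()
        dist.all_reduce(grad, group=ctx.group)  # sum gradients over the group
        return grad, None
```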
@@ -335,6 +468,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
Copy tensor to azimuth region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the azimuth region in distributed settings.
"""
@staticmethod
@@ -347,11 +481,15 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
r"""
Forward pass for copying to azimuth region.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
input_: torch.Tensor
The input tensor, unchanged (the forward pass is a no-op)
"""
return input_
@@ -361,11 +499,15 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
r"""
Backward pass for copying to azimuth region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_input: torch.Tensor
The gradient of the input (all-reduced over the azimuth group when distributed)
"""
if is_distributed_azimuth():
return _reduce(grad_output, group=azimuth_group())
@@ -378,6 +520,18 @@ class _ScatterToPolarRegion(torch.autograd.Function):
Scatter tensor to polar region for distributed computation.
This class provides the forward and backward passes for scattering
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to scatter
dim_: int
The dimension to scatter along
Returns
-------
output: torch.Tensor
The scattered tensor
"""
@staticmethod
@@ -406,8 +560,23 @@ class _ScatterToPolarRegion(torch.autograd.Function):
class _GatherFromPolarRegion(torch.autograd.Function):
"""Gather the input and keep it on the rank."""
r"""
Gather the input from the polar region and keep it on the rank.
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
return _gather(input_, dim_, shapes_, polar_group())
@@ -431,7 +600,19 @@ class _GatherFromPolarRegion(torch.autograd.Function):
class _ReduceFromPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region."""
r"""
All-reduce the input from the polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
@@ -455,8 +636,19 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
class _ReduceFromAzimuthRegion(torch.autograd.Function):
"""All-reduce the input from the azimuth region."""
r"""
All-reduce the input from the azimuth region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
if is_distributed_azimuth():
@@ -479,8 +671,21 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region and scatter back to polar region."""
r"""
All-reduce the input from the polar region and scatter back to polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
dim_: int
The dimension to reduce along
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_, dim_):
if is_distributed_polar():
@@ -510,7 +715,23 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
class _GatherFromCopyToPolarRegion(torch.autograd.Function):
"""Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter"""
r"""
Gather the input from the polar region and register an all-reduce for the backward pass; essentially the inverse of a reduce-scatter.
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
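The gather and scatter pairs above encode the usual duality: the adjoint of an all-gather along a dimension is a reduce-scatter along the same dimension, and vice versa. A simplified sketch of that duality (assuming equal chunk sizes; the real code supports uneven splits via the shapes_ argument):

```python
import torch
import torch.distributed as dist

class AllGatherSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, group):
        ctx.dim, ctx.group = dim, group
        world = dist.get_world_size(group=group)
        parts = [torch.empty_like(input_) for _ in range(world)]
        dist.all_gather(parts, input_.contiguous(), group=group)
        return torch.cat(parts, dim=dim)  # forward: all-gather along dim

    @staticmethod
    def backward(ctx, grad_output):
        world = dist.get_world_size(group=ctx.group)
        parts = [p.contiguous() for p in torch.chunk(grad_output, world, dim=ctx.dim)]
        grad = torch.empty_like(parts[dist.get_rank(group=ctx.group)])
        dist.reduce_scatter(grad, parts, group=ctx.group)  # backward: reduce-scatter
        return grad, None, None
```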