Commit 50ebe96f authored by apaaris, committed by Boris Bonev

Improved docstrings in distributed

parent 10cf65f6
@@ -74,6 +74,57 @@ def _split_distributed_convolution_tensor_s2(
in_shape: Tuple[int],
out_shape: Tuple[int],
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which applied to the north pole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
Parameters
----------
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
filter_basis: FilterBasis
Filter basis to use
grid_in: str
Grid type for the input tensor
grid_out: str
Grid type for the output tensor
theta_cutoff: float
Theta cutoff for the filter basis
theta_eps: float
Epsilon for the theta cutoff
transpose_normalization: bool
Whether to transpose the normalization
basis_norm_mode: str
Normalization mode for the filter basis
merge_quadrature: bool
Whether to merge the quadrature weights
Returns
-------
out_idx: torch.Tensor
Indices of the output tensor
out_vals: torch.Tensor
Values of the output tensor
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
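
As a quick numerical check of the closed-form expression above, a small standalone sketch (NumPy; the helper names are illustrative, not from this commit) applies the YZY rotation matrices to the north pole and compares against the stated components:

import numpy as np

def Ry(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def Rz(b):
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

alpha, beta, gamma = 0.3, 1.1, -0.7
n = np.array([0.0, 0.0, 1.0])
rotated = Ry(alpha) @ Rz(beta) @ Ry(gamma) @ n

# components as given in the docstring formula
expected = np.array([
    np.cos(gamma) * np.sin(alpha) + np.cos(alpha) * np.cos(beta) * np.sin(gamma),
    np.sin(beta) * np.sin(gamma),
    np.cos(alpha) * np.cos(gamma) - np.cos(beta) * np.sin(alpha) * np.sin(gamma),
])
assert np.allclose(rotated, expected)
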
@@ -102,10 +153,43 @@ def _split_distributed_convolution_tensor_s2(
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
We assume the data can be split in the polar and azimuthal directions.
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
def __init__(
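
For orientation, a hedged usage sketch based only on the parameter list above; the channel counts, shapes, and the requirement that the distributed environment is already initialized are assumptions, not taken from this diff:

# Hypothetical usage; assumes torch.distributed and the library's
# polar/azimuth process groups have been initialized beforehand.
conv = DistributedDiscreteContinuousConvS2(
    in_channels=32,
    out_channels=64,
    in_shape=(128, 256),   # (nlat_in, nlon_in)
    out_shape=(64, 128),   # (nlat_out, nlon_out)
    kernel_shape=3,
)
# x: the local (lat/lon-sharded) piece of a (batch, 32, 128, 256) tensor
# y = conv(x)  -> local piece of a (batch, 64, 64, 128) tensor
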
@@ -247,6 +331,40 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
We assume the data can be split in the polar and azimuthal directions.
...
@@ -48,6 +48,30 @@ class DistributedRealSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -56,10 +80,22 @@ class DistributedRealSHT(nn.Module):
"""
Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
""" """
super().__init__() super().__init__()
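
A hedged sketch of the expected call pattern; the concrete shard shapes depend on the polar/azimuth process-group layout and are assumptions here:

# Hypothetical usage sketch.
sht = DistributedRealSHT(nlat=180, nlon=360, grid="equiangular", norm="ortho")
# x: local shard of a (..., channels, nlat, nlon) real tensor
# coeffs = sht(x)  -> local shard of a complex (..., channels, lmax, mmax) tensor
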
@@ -168,9 +204,31 @@ class DistributedInverseRealSHT(nn.Module):
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., nlat, nlon)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -282,6 +340,30 @@ class DistributedRealVectorSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., 2, lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
@@ -290,10 +372,18 @@ class DistributedRealVectorSHT(nn.Module):
"""
Initializes the vector SHT layer, precomputing the necessary quadrature weights.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
""" """
super().__init__() super().__init__()
...@@ -425,6 +515,30 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -425,6 +515,30 @@ class DistributedInverseRealVectorSHT(nn.Module):
Defines a module for computing the inverse (real-valued) vector SHT. Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., 2, nlat, nlon)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
...
@@ -39,6 +39,19 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
# helper routine to compute an uneven split in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
"""
Compute the split shapes for a given size and number of chunks.
Parameters
----------
size: int
The size of the tensor to split
num_chunks: int
The number of chunks to split into
Returns
-------
List[int]
The split shapes
"""
# treat trivial case first
if num_chunks == 1:
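
The rest of the body is elided in this diff; a minimal sketch of the balanced-splitting rule the helper comment describes (an assumption about the implementation, not a copy of it) is that the first size % num_chunks chunks receive one extra element:

def _balanced_split_sketch(size: int, num_chunks: int) -> list:
    # chunk sizes differ by at most one element
    base, rem = divmod(size, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

# _balanced_split_sketch(10, 4) -> [3, 3, 2, 2]
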
@@ -59,6 +72,24 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
def split_tensor_along_dim(tensor, dim, num_chunks):
"""
Split a tensor along a given dimension into a given number of chunks.
Parameters
----------
tensor: torch.Tensor
The tensor to split
dim: int
The dimension to split along
num_chunks: int
The number of chunks to split into
Returns
-------
tensor_list: List[torch.Tensor]
The split tensors
"""
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
{num_chunks} chunks. Empty slices are currently not supported."
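
A short usage sketch; the exact chunk sizes assume the balanced splitting behavior sketched above:

import torch

x = torch.randn(4, 8, 10)
chunks = split_tensor_along_dim(x, dim=2, num_chunks=4)
# balanced uneven split along dim 2: sizes [3, 3, 2, 2]
assert [c.shape[2] for c in chunks] == [3, 3, 2, 2]
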
@@ -71,6 +102,26 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
"""
Distributed transpose of a tensor between two dimensions, exchanging which of the two is split across ranks.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
group: ProcessGroup, optional
The process group to communicate over
async_op: bool, optional
Whether to perform the communication asynchronously
Returns
-------
tensor_list: List[torch.Tensor]
The received tensor chunks
dim0_split_sizes: List[int]
The split sizes along the first dimension
req: communication request handle
The pending request when async_op is used
"""
# get comm params
comm_size = dist.get_world_size(group=group)
comm_rank = dist.get_rank(group=group)
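
The remainder of the body is elided, but the call sites below unpack (tensor_list, dim0_split_sizes, req), which suggests the usual all-to-all transpose pattern. A minimal sketch under that assumption (not the actual implementation):

import torch
import torch.distributed as dist

def _transpose_sketch(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    # carve the local tensor along dim0, one chunk per destination rank
    dim0_split_sizes = compute_split_shapes(tensor.shape[dim0], comm_size)
    x_send = [c.contiguous() for c in torch.split(tensor, dim0_split_sizes, dim=dim0)]
    # receive buffers: dim0 extent is our own chunk, dim1 extent is the sender's shard
    x_recv = []
    for r in range(comm_size):
        shape = list(x_send[comm_rank].shape)
        shape[dim1] = dim1_split_sizes[r]
        x_recv.append(torch.empty(shape, dtype=tensor.dtype, device=tensor.device))
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
    return x_recv, dim0_split_sizes, req
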
@@ -99,6 +150,26 @@ class distributed_transpose_azimuth(torch.autograd.Function):
Distributed transpose operation for the azimuthal dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the azimuthal dimension.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
""" """
@staticmethod @staticmethod
...@@ -107,10 +178,19 @@ class distributed_transpose_azimuth(torch.autograd.Function): ...@@ -107,10 +178,19 @@ class distributed_transpose_azimuth(torch.autograd.Function):
r""" r"""
Forward pass for distributed azimuthal transpose. Forward pass for distributed azimuthal transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
""" """
# WAR for a potential contig check torch bug for channels last contig tensors # WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group()) xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
@@ -126,11 +206,15 @@ class distributed_transpose_azimuth(torch.autograd.Function):
r"""
Backward pass for distributed azimuthal transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
""" """
dims = ctx.dims dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes dim0_split_sizes = ctx.dim0_split_sizes
@@ -146,6 +230,24 @@ class distributed_transpose_polar(torch.autograd.Function):
Distributed transpose operation for the polar dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the polar dimension.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
""" """
@staticmethod @staticmethod
...@@ -154,10 +256,19 @@ class distributed_transpose_polar(torch.autograd.Function): ...@@ -154,10 +256,19 @@ class distributed_transpose_polar(torch.autograd.Function):
r""" r"""
Forward pass for distributed polar transpose. Forward pass for distributed polar transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dim: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
""" """
# WAR for a potential contig check torch bug for channels last contig tensors # WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group()) xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
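
A hedged call-site sketch; autograd Functions are invoked through .apply, and the specific dimension indices here are illustrative assumptions:

# x: local shard split along dim 3; after the call it is split along dim 2
# instead, with dim1_split_sizes describing the current dim-3 partitioning.
y = distributed_transpose_polar.apply(x, (2, 3), dim1_split_sizes)
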
@@ -172,11 +283,15 @@ class distributed_transpose_polar(torch.autograd.Function):
r"""
Backward pass for distributed polar transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
""" """
dim = ctx.dim dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes dim0_split_sizes = ctx.dim0_split_sizes
@@ -292,6 +407,16 @@ class _CopyToPolarRegion(torch.autograd.Function):
Copy tensor to polar region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
output: torch.Tensor
The input tensor, unchanged (gradients are all-reduced in the backward pass)
""" """
@staticmethod @staticmethod
...@@ -304,11 +429,15 @@ class _CopyToPolarRegion(torch.autograd.Function): ...@@ -304,11 +429,15 @@ class _CopyToPolarRegion(torch.autograd.Function):
r""" r"""
Forward pass for copying to polar region. Forward pass for copying to polar region.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
input_: torch.Tensor
The input tensor (no-op in the forward pass)
""" """
return input_ return input_
@@ -318,11 +447,15 @@ class _CopyToPolarRegion(torch.autograd.Function):
r"""
Backward pass for copying to polar region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_input: torch.Tensor
The gradient of the input (all-reduced over the polar group when distributed)
""" """
if is_distributed_polar(): if is_distributed_polar():
return _reduce(grad_output, group=polar_group()) return _reduce(grad_output, group=polar_group())
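
Identity forward with all-reduce backward is the standard "copy to region" primitive; a self-contained sketch of the pattern (group handling simplified, names hypothetical):

import torch
import torch.distributed as dist

class _CopySketch(torch.autograd.Function):
    # forward: hand the tensor through unchanged
    @staticmethod
    def forward(ctx, x):
        return x

    # backward: sum gradient contributions from every rank in the group
    @staticmethod
    def backward(ctx, grad_output):
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad_output)
        return grad_output
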
@@ -335,6 +468,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
Copy tensor to azimuth region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the azimuth region in distributed settings.
""" """
@staticmethod @staticmethod
...@@ -347,11 +481,15 @@ class _CopyToAzimuthRegion(torch.autograd.Function): ...@@ -347,11 +481,15 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
r""" r"""
Forward pass for copying to azimuth region. Forward pass for copying to azimuth region.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
input_: torch.Tensor
The input tensor (no-op in the forward pass)
""" """
return input_ return input_
@@ -361,11 +499,15 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
r"""
Backward pass for copying to azimuth region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_input: torch.Tensor
The gradient of the input (all-reduced over the azimuth group when distributed)
""" """
if is_distributed_azimuth(): if is_distributed_azimuth():
return _reduce(grad_output, group=azimuth_group()) return _reduce(grad_output, group=azimuth_group())
@@ -378,6 +520,18 @@ class _ScatterToPolarRegion(torch.autograd.Function):
Scatter tensor to polar region for distributed computation.
This class provides the forward and backward passes for scattering
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to scatter
dim_: int
The dimension to scatter along
Returns
-------
output: torch.Tensor
The scattered tensor
""" """
@staticmethod @staticmethod
...@@ -406,8 +560,23 @@ class _ScatterToPolarRegion(torch.autograd.Function): ...@@ -406,8 +560,23 @@ class _ScatterToPolarRegion(torch.autograd.Function):
class _GatherFromPolarRegion(torch.autograd.Function): class _GatherFromPolarRegion(torch.autograd.Function):
"""Gather the input and keep it on the rank.""" r"""
Gather the input and keep it on the rank.
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
return _gather(input_, dim_, shapes_, polar_group())
@@ -431,7 +600,19 @@ class _GatherFromPolarRegion(torch.autograd.Function):
class _ReduceFromPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region.""" r"""
All-reduce the input from the polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
@@ -455,8 +636,19 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
class _ReduceFromAzimuthRegion(torch.autograd.Function):
"""All-reduce the input from the azimuth region.""" r"""
All-reduce the input from the azimuth region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
if is_distributed_azimuth():
@@ -479,8 +671,21 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region and scatter back to polar region.""" r"""
All-reduce the input from the polar region and scatter back to polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
dim_: int
The dimension to scatter along
Returns
-------
output: torch.Tensor
The reduced and scattered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_):
if is_distributed_polar():
@@ -510,7 +715,23 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
class _GatherFromCopyToPolarRegion(torch.autograd.Function):
"""Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter""" r"""
Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
...