Commit eeda67aa authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Improved docstrings in convolution

parent ffee67f9
......@@ -65,6 +65,34 @@ def _normalize_convolution_tensor_s2(
- "none": No normalization is applied.
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
Parameters
----------
psi_idx: torch.Tensor
Index tensor of the convolution tensor
psi_vals: torch.Tensor
Values tensor of the convolution tensor
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_size: int
Size of the kernel
quad_weights: torch.Tensor
Quadrature weights
transpose_normalization: bool
Whether to normalize the convolution tensor in the transpose direction
basis_norm_mode: str
Mode for basis normalization
merge_quadrature: bool
Whether to merge the quadrature weights into the convolution tensor
eps: float
Small epsilon to avoid division by zero
Returns
-------
psi_vals: torch.Tensor
Normalized convolution tensor
"""
# exit here if no normalization is needed
......@@ -166,6 +194,37 @@ def _precompute_convolution_tensor_s2(
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
Parameters
----------
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
filter_basis: FilterBasis
Filter basis functions
grid_in: str
Input grid type
grid_out: str
Output grid type
theta_cutoff: float
Theta cutoff for the filter basis functions
theta_eps: float
Epsilon for the theta cutoff
transpose_normalization: bool
Whether to normalize the convolution tensor in the transpose direction
basis_norm_mode: str
Mode for basis normalization
merge_quadrature: bool
Whether to merge the quadrature weights into the convolution tensor
Returns
-------
out_idx: torch.Tensor
Index tensor of the convolution tensor
out_vals: torch.Tensor
Values tensor of the convolution tensor
"""
assert len(in_shape) == 2
......@@ -268,6 +327,26 @@ def _precompute_convolution_tensor_s2(
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for discrete-continuous convolutions
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
groups: Optional[int]
Number of groups
bias: Optional[bool]
Whether to use bias
Returns
-------
out: torch.Tensor
Output tensor
"""
def __init__(
......@@ -316,6 +395,40 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
basis_norm_mode: Optional[str]
Mode for basis normalization
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Input grid type
grid_out: Optional[str]
Output grid type
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -389,8 +502,28 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
@property
def psi_idx(self):
    """
    Assembled index tensor of the sparse convolution tensor.

    Returns
    -------
    psi_idx: torch.Tensor
        Kernel, row and column indices stacked along dim 0, contiguous in memory.
    """
    components = (self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx)
    stacked = torch.stack(components, dim=0)
    return stacked.contiguous()
def get_psi(self):
    """
    Assemble the sparse convolution tensor.

    Returns
    -------
    psi: torch.Tensor
        Coalesced sparse COO tensor of shape
        (kernel_size, nlat_out, nlat_in * nlon_in) built from
        self.psi_idx and self.psi_vals.
    """
    shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
    return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=shape).coalesce()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.is_cuda and _cuda_extension_available:
......@@ -420,6 +553,40 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
basis_norm_mode: Optional[str]
Mode for basis normalization
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Input grid type
grid_out: Optional[str]
Output grid type
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -496,7 +663,52 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
def psi_idx(self):
    """
    Assembled index tensor of the sparse convolution tensor.

    Returns
    -------
    psi_idx: torch.Tensor
        Kernel, row and column indices stacked along dim 0, contiguous in memory.
    """
    parts = [self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx]
    return torch.stack(parts, dim=0).contiguous()
# NOTE(review): unresolved merge-conflict markers (<<<<<<< HEAD / ======= / >>>>>>> 4578beb)
# were committed around this method, which is a syntax error at import time.
# Resolved by keeping the incoming branch (the method below); the HEAD side was empty.
def get_psi(self, semi_transposed: bool = False):
    """
    Assemble the sparse convolution tensor of the transpose convolution.

    Parameters
    ----------
    semi_transposed: bool
        If True, re-index the tensor so the output latitude becomes the row
        dimension and the longitude axis is flipped, which facilitates
        applying the transpose convolution like a forward convolution.

    Returns
    -------
    psi: torch.Tensor
        Coalesced sparse COO tensor. Shape is
        (kernel_size, nlat_out, nlat_in * nlon_out) when semi_transposed,
        otherwise (kernel_size, nlat_in, nlat_out * nlon_out).
    """
    if semi_transposed:
        # we do a semi-transposition to facilitate the computation
        tout = self.psi_idx[2] // self.nlon_out
        pout = self.psi_idx[2] % self.nlon_out
        # flip the axis of longitudes
        pout = self.nlon_out - 1 - pout
        tin = self.psi_idx[1]
        idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
        psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
    else:
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
    return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass
Parameters
----------
x: torch.Tensor
Input tensor
Returns
-------
out: torch.Tensor
Output tensor
"""
# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment