Commit 1ef713bb authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

converted many docstrings to oneliners

parent a8f2af6c
...@@ -256,9 +256,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -256,9 +256,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local) self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property @property
...@@ -437,9 +434,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -437,9 +434,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True) self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property @property
......
...@@ -43,7 +43,7 @@ from torch_harmonics.distributed import compute_split_shapes ...@@ -43,7 +43,7 @@ from torch_harmonics.distributed import compute_split_shapes
class DistributedResampleS2(nn.Module): class DistributedResampleS2(nn.Module):
r""" """
Distributed resampling module for spherical data on the 2-sphere. Distributed resampling module for spherical data on the 2-sphere.
This module performs distributed resampling of spherical data across multiple processes, This module performs distributed resampling of spherical data across multiple processes,
...@@ -156,19 +156,7 @@ class DistributedResampleS2(nn.Module): ...@@ -156,19 +156,7 @@ class DistributedResampleS2(nn.Module):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor): def _upscale_longitudes(self, x: torch.Tensor):
""" """Upscale the longitude dimension using interpolation."""
Upscale the longitude dimension using interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Upscaled tensor in the longitude dimension
"""
# do the interpolation # do the interpolation
lwgt = self.lon_weights.to(x.dtype) lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
...@@ -183,19 +171,7 @@ class DistributedResampleS2(nn.Module): ...@@ -183,19 +171,7 @@ class DistributedResampleS2(nn.Module):
return x return x
def _expand_poles(self, x: torch.Tensor): def _expand_poles(self, x: torch.Tensor):
""" """Expand the data to include pole values for interpolation."""
Expand the data to include pole values for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Tensor with expanded pole values
"""
x_north = x[..., 0, :].sum(dim=-1, keepdims=True) x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
x_south = x[..., -1, :].sum(dim=-1, keepdims=True) x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False) x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
...@@ -218,19 +194,7 @@ class DistributedResampleS2(nn.Module): ...@@ -218,19 +194,7 @@ class DistributedResampleS2(nn.Module):
return x return x
def _upscale_latitudes(self, x: torch.Tensor): def _upscale_latitudes(self, x: torch.Tensor):
""" """Upscale the latitude dimension using interpolation."""
Upscale the latitude dimension using interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Upscaled tensor in the latitude dimension
"""
# do the interpolation # do the interpolation
lwgt = self.lat_weights.to(x.dtype) lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
......
...@@ -131,9 +131,6 @@ class DistributedRealSHT(nn.Module): ...@@ -131,9 +131,6 @@ class DistributedRealSHT(nn.Module):
self.register_buffer('weights', weights, persistent=False) self.register_buffer('weights', weights, persistent=False)
def extra_repr(self): def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -264,9 +261,6 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -264,9 +261,6 @@ class DistributedInverseRealSHT(nn.Module):
self.register_buffer('pct', pct, persistent=False) self.register_buffer('pct', pct, persistent=False)
def extra_repr(self): def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -409,9 +403,6 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -409,9 +403,6 @@ class DistributedRealVectorSHT(nn.Module):
def extra_repr(self): def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -556,9 +547,6 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -556,9 +547,6 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.register_buffer('dpct', dpct, persistent=False) self.register_buffer('dpct', dpct, persistent=False)
def extra_repr(self): def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
......
...@@ -39,19 +39,7 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth ...@@ -39,19 +39,7 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
# helper routine to compute uneven splitting in balanced way: # helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]: def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
""" """Compute the split shapes for a given size and number of chunks."""
Compute the split shapes for a given size and number of chunks.
Parameters
----------
size: int
The size of the tensor to split
Returns
-------
List[int]
The split shapes
"""
# treat trivial case first # treat trivial case first
if num_chunks == 1: if num_chunks == 1:
...@@ -72,23 +60,7 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]: ...@@ -72,23 +60,7 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
def split_tensor_along_dim(tensor, dim, num_chunks): def split_tensor_along_dim(tensor, dim, num_chunks):
""" """Split a tensor along a given dimension into a given number of chunks."""
Split a tensor along a given dimension into a given number of chunks.
Parameters
----------
tensor: torch.Tensor
The tensor to split
dim: int
The dimension to split along
num_chunks: int
The number of chunks to split into
Returns
-------
tensor_list: List[torch.Tensor]
The split tensors
"""
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
......
...@@ -293,11 +293,7 @@ class SphericalLossBase(nn.Module, ABC): ...@@ -293,11 +293,7 @@ class SphericalLossBase(nn.Module, ABC):
class SquaredL2LossS2(SphericalLossBase): class SquaredL2LossS2(SphericalLossBase):
""" """Squared L2 loss for spherical regression tasks."""
Squared L2 loss for spherical regression tasks.
Computes the squared difference between prediction and target tensors.
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
...@@ -305,11 +301,7 @@ class SquaredL2LossS2(SphericalLossBase): ...@@ -305,11 +301,7 @@ class SquaredL2LossS2(SphericalLossBase):
class L1LossS2(SphericalLossBase): class L1LossS2(SphericalLossBase):
""" """L1 loss for spherical regression tasks."""
L1 loss for spherical regression tasks.
Computes the absolute difference between prediction and target tensors.
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
...@@ -317,11 +309,7 @@ class L1LossS2(SphericalLossBase): ...@@ -317,11 +309,7 @@ class L1LossS2(SphericalLossBase):
class L2LossS2(SquaredL2LossS2): class L2LossS2(SquaredL2LossS2):
""" """L2 loss for spherical regression tasks."""
L2 loss for spherical regression tasks.
Computes the square root of the squared L2 loss.
"""
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor: def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
...@@ -329,11 +317,7 @@ class L2LossS2(SquaredL2LossS2): ...@@ -329,11 +317,7 @@ class L2LossS2(SquaredL2LossS2):
class W11LossS2(SphericalLossBase): class W11LossS2(SphericalLossBase):
""" """W11 loss for spherical regression tasks."""
W11 loss for spherical regression tasks.
Computes the L1 norm of the gradient differences between prediction and target.
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"): def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
......
...@@ -117,58 +117,15 @@ class SphereSolver(nn.Module): ...@@ -117,58 +117,15 @@ class SphereSolver(nn.Module):
self.register_buffer('invlap', invlap) self.register_buffer('invlap', invlap)
def grid2spec(self, u): def grid2spec(self, u):
"""
Convert spatial data to spectral coefficients.
Parameters
-----------
u : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return self.sht(u) return self.sht(u)
def spec2grid(self, uspec): def spec2grid(self, uspec):
""" """Convert spectral coefficients to spatial data."""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return self.isht(uspec) return self.isht(uspec)
def dudtspec(self, uspec, pde='allen-cahn'): def dudtspec(self, uspec, pde='allen-cahn'):
""" """Compute the time derivative of spectral coefficients for different PDEs."""
Compute the time derivative of spectral coefficients for different PDEs.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients
pde : str, optional
PDE type ("allen-cahn", "ginzburg-landau"), by default "allen-cahn"
Returns
-------
torch.Tensor
Time derivative of spectral coefficients
Raises
------
NotImplementedError
If PDE type is not supported
"""
if pde == 'allen-cahn': if pde == 'allen-cahn':
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
u3spec = self.grid2spec(ugrid**3) u3spec = self.grid2spec(ugrid**3)
...@@ -183,14 +140,7 @@ class SphereSolver(nn.Module): ...@@ -183,14 +140,7 @@ class SphereSolver(nn.Module):
return dudtspec return dudtspec
def randspec(self): def randspec(self):
""" """Generate random spectral data on the sphere."""
Generate random spectral data on the sphere.
Returns
-------
torch.Tensor
Random spectral coefficients
"""
rspec = torch.randn_like(self.lap) / 4 / torch.pi rspec = torch.randn_like(self.lap) / 4 / torch.pi
return rspec return rspec
...@@ -273,21 +223,5 @@ class SphereSolver(nn.Module): ...@@ -273,21 +223,5 @@ class SphereSolver(nn.Module):
return im return im
def plot_specdata(self, data, fig, **kwargs): def plot_specdata(self, data, fig, **kwargs):
""" """Plot spectral data by converting to spatial data first."""
Plot spectral data by converting to spatial data first.
Parameters
-----------
data : torch.Tensor
Spectral data to plot
fig : matplotlib.figure.Figure
Figure to plot on
**kwargs
Additional arguments passed to plot_griddata
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
return self.plot_griddata(self.isht(data), fig, **kwargs) return self.plot_griddata(self.isht(data), fig, **kwargs)
...@@ -704,7 +704,8 @@ class StanfordDepthDataset(Dataset): ...@@ -704,7 +704,8 @@ class StanfordDepthDataset(Dataset):
def compute_stats_s2(dataset: Dataset, normalize_target: bool = False): def compute_stats_s2(dataset: Dataset, normalize_target: bool = False):
""" """
Compute stats using parallel welford reduction and quadrature on the sphere. The parallel welford reduction follows this article (parallel algorithm): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Compute stats using parallel welford reduction and quadrature on the sphere.
The parallel welford reduction follows this article (parallel algorithm): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
""" """
nexamples = len(dataset) nexamples = len(dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment