Commit ec53e666 authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

further cleanup

parent 1ef713bb
......@@ -42,35 +42,7 @@ except ImportError as err:
# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations.
This function constructs a sparse COO tensor from indices and values, with optional
semi-transposition for computational efficiency in spherical harmonic convolutions.
Args:
kernel_size: Number of kernel elements.
psi_idx: Tensor of shape (3, n_nonzero) containing the indices for the sparse tensor.
The three dimensions represent [kernel_idx, lat_idx, combined_lat_lon_idx].
psi_vals: Tensor of shape (n_nonzero,) containing the values for the sparse tensor.
nlat_in: Number of input latitude points.
nlon_in: Number of input longitude points.
nlat_out: Number of output latitude points.
nlon_out: Number of output longitude points.
nlat_in_local: Local number of input latitude points. If None, defaults to nlat_in.
nlat_out_local: Local number of output latitude points. If None, defaults to nlat_out.
semi_transposed: If True, performs a semi-transposition to facilitate computation
by flipping the longitude axis and reorganizing indices.
Returns:
torch.Tensor: A sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon)
where nlon is either nlon_in or nlon_out depending on semi_transposed flag.
The tensor is coalesced to remove duplicate indices.
Note:
When semi_transposed=True, the function performs a partial transpose operation
that flips the longitude axis and reorganizes the indices to facilitate
efficient spherical harmonic convolution computations.
"""
"""Creates a sparse tensor for spherical harmonic convolution operations."""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
......@@ -90,7 +62,6 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
......@@ -123,7 +94,6 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
......@@ -169,7 +139,6 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
assert len(psi.shape) == 3
assert len(x.shape) == 4
psi = psi.to(x.device)
......
......@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
self.proj_bias = None
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......
......@@ -501,9 +501,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......@@ -660,9 +657,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......
......@@ -118,9 +118,6 @@ class RealSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -223,9 +220,6 @@ class InverseRealSHT(nn.Module):
self.register_buffer("pct", pct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -332,9 +326,6 @@ class RealVectorSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -449,9 +440,6 @@ class InverseRealVectorSHT(nn.Module):
self.register_buffer("dpct", dpct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment