Commit ec53e666 authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

further cleanup

parent 1ef713bb
...@@ -42,35 +42,7 @@ except ImportError as err: ...@@ -42,35 +42,7 @@ except ImportError as err:
# some helper functions # some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False): def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations. """Creates a sparse tensor for spherical harmonic convolution operations."""
This function constructs a sparse COO tensor from indices and values, with optional
semi-transposition for computational efficiency in spherical harmonic convolutions.
Args:
kernel_size: Number of kernel elements.
psi_idx: Tensor of shape (3, n_nonzero) containing the indices for the sparse tensor.
The three dimensions represent [kernel_idx, lat_idx, combined_lat_lon_idx].
psi_vals: Tensor of shape (n_nonzero,) containing the values for the sparse tensor.
nlat_in: Number of input latitude points.
nlon_in: Number of input longitude points.
nlat_out: Number of output latitude points.
nlon_out: Number of output longitude points.
nlat_in_local: Local number of input latitude points. If None, defaults to nlat_in.
nlat_out_local: Local number of output latitude points. If None, defaults to nlat_out.
semi_transposed: If True, performs a semi-transposition to facilitate computation
by flipping the longitude axis and reorganizing indices.
Returns:
torch.Tensor: A sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon)
where nlon is either nlon_in or nlon_out depending on semi_transposed flag.
The tensor is coalesced to remove duplicate indices.
Note:
When semi_transposed=True, the function performs a partial transpose operation
that flips the longitude axis and reorganizes the indices to facilitate
efficient spherical harmonic convolution computations.
"""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
...@@ -90,7 +62,6 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl ...@@ -90,7 +62,6 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
class _DiscoS2ContractionCuda(torch.autograd.Function): class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(device_type="cuda") @custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
...@@ -123,7 +94,6 @@ class _DiscoS2ContractionCuda(torch.autograd.Function): ...@@ -123,7 +94,6 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
class _DiscoS2TransposeContractionCuda(torch.autograd.Function): class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(device_type="cuda") @custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
...@@ -169,7 +139,6 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor ...@@ -169,7 +139,6 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 4 assert len(x.shape) == 4
psi = psi.to(x.device) psi = psi.to(x.device)
......
...@@ -142,9 +142,6 @@ class AttentionS2(nn.Module): ...@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
...@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module): ...@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
self.proj_bias = None self.proj_bias = None
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......
...@@ -501,9 +501,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -501,9 +501,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out) self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property @property
...@@ -660,9 +657,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -660,9 +657,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property @property
......
...@@ -118,9 +118,6 @@ class RealSHT(nn.Module): ...@@ -118,9 +118,6 @@ class RealSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False) self.register_buffer("weights", weights, persistent=False)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}" return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -223,9 +220,6 @@ class InverseRealSHT(nn.Module): ...@@ -223,9 +220,6 @@ class InverseRealSHT(nn.Module):
self.register_buffer("pct", pct, persistent=False) self.register_buffer("pct", pct, persistent=False)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}" return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -332,9 +326,6 @@ class RealVectorSHT(nn.Module): ...@@ -332,9 +326,6 @@ class RealVectorSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False) self.register_buffer("weights", weights, persistent=False)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}" return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
...@@ -449,9 +440,6 @@ class InverseRealVectorSHT(nn.Module): ...@@ -449,9 +440,6 @@ class InverseRealVectorSHT(nn.Module):
self.register_buffer("dpct", dpct, persistent=False) self.register_buffer("dpct", dpct, persistent=False)
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}" return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment