Commit 61dd6cf1 authored by Boris Bonev

removing get_psi which got added from merge

parent ba7a4996
@@ -61,7 +61,7 @@ def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
"""Normalizes convolution tensor values based on specified normalization mode.
This function applies different normalization strategies to the convolution tensor
values based on the basis_norm_mode parameter. It can normalize individual basis
functions, compute mean normalization across all basis functions, or use support
@@ -143,7 +143,6 @@ def _normalize_convolution_tensor_s2(
                # compute the support
                support[ik, ilat] = torch.sum(q[iidx])

    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):
@@ -166,7 +165,6 @@ def _normalize_convolution_tensor_s2(
                if merge_quadrature:
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx]

    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
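As a reading aid for this normalization logic: below is a minimal dense sketch of the "mean" mode described in the docstring, with toy shapes and an illustrative helper name (`normalize_mean`); the real function operates on the sparse `psi_vals` entries selected through `psi_idx`.

```python
import torch

# Hedged sketch of "mean" basis normalization: every basis function is
# divided by the average quadrature-weighted norm across the kernel.
# Dense toy tensors; the actual code indexes sparse values via psi_idx.
def normalize_mean(vals, q, merge_quadrature=False, eps=1e-9):
    # quadrature-weighted norm of each basis function
    norms = torch.sum(vals * q, dim=-1, keepdim=True).abs()
    # "mean" mode: normalize by the mean norm over all basis functions
    vals = vals / (norms.mean(dim=0, keepdim=True) + eps)
    if merge_quadrature:
        # fold the quadrature weights into the tensor values
        vals = vals * q
    return vals

vals = torch.rand(3, 8)                  # kernel_size=3, 8 support points
q = torch.full((8,), 4 * torch.pi / 8)   # toy quadrature weights
print(normalize_mean(vals, q, merge_quadrature=True).shape)
```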
@@ -178,13 +176,13 @@ def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
-   grid_in: Optional[str]="equiangular",
-   grid_out: Optional[str]="equiangular",
-   theta_cutoff: Optional[float]=0.01 * math.pi,
-   theta_eps: Optional[float]=1e-3,
-   transpose_normalization: Optional[bool]=False,
-   basis_norm_mode: Optional[str]="mean",
-   merge_quadrature: Optional[bool]=False,
+   grid_in: Optional[str] = "equiangular",
+   grid_out: Optional[str] = "equiangular",
+   theta_cutoff: Optional[float] = 0.01 * math.pi,
+   theta_eps: Optional[float] = 1e-3,
+   transpose_normalization: Optional[bool] = False,
+   basis_norm_mode: Optional[str] = "mean",
+   merge_quadrature: Optional[bool] = False,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
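The rotation in this docstring can be sanity-checked numerically. A small sketch, with illustrative helpers `Y` and `Z` for rotations about the y- and z-axes (not functions from this module):

```python
import math
import torch

def Y(a):  # rotation about the y-axis
    c, s = math.cos(a), math.sin(a)
    return torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def Z(a):  # rotation about the z-axis
    c, s = math.cos(a), math.sin(a)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

nu = torch.tensor([0.0, 0.0, 1.0])  # north pole
theta_j, phi_i, phi_j = 0.4, 1.1, 0.3
# R^{-1}_j R_i nu = Y(-theta_j) Z(phi_i - phi_j) Y(theta_j) nu
omega = Y(-theta_j) @ Z(phi_i - phi_j) @ Y(theta_j) @ nu
print(omega)  # rotated filter position on the unit sphere
```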
@@ -515,18 +513,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
-   def get_psi(self):
-       """
-       Get the convolution tensor
-
-       Returns
-       -------
-       psi: torch.Tensor
-           Convolution tensor
-       """
-       psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
-       return psi
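For reference, the sparse tensor the removed method assembled can still be built by hand from the index and value buffers; a toy sketch with made-up shapes and indices (not the module's actual buffers):

```python
import torch

# Assemble a sparse COO convolution tensor of shape
# (kernel_size, nlat_out, nlat_in * nlon_in), as the removed get_psi did.
kernel_size, nlat_out, nlat_in, nlon_in = 2, 3, 3, 4
idx = torch.tensor([[0, 1],    # kernel index
                    [0, 2],    # output latitude
                    [1, 5]])   # flattened input grid point
vals = torch.tensor([0.5, 0.25])
psi = torch.sparse_coo_tensor(idx, vals, size=(kernel_size, nlat_out, nlat_in * nlon_in)).coalesce()
print(psi.to_dense().shape)  # torch.Size([2, 3, 12])
```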
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_cuda and _cuda_extension_available:
@@ -582,7 +568,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
@@ -663,23 +649,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    def psi_idx(self):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
-   def get_psi(self, semi_transposed: bool = False):
-       if semi_transposed:
-           # we do a semi-transposition to facilitate the computation
-           tout = self.psi_idx[2] // self.nlon_out
-           pout = self.psi_idx[2] % self.nlon_out
-           # flip the axis of longitudes
-           pout = self.nlon_out - 1 - pout
-           tin = self.psi_idx[1]
-           idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
-           psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
-       else:
-           psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
-       return psi
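The longitude flip in this removed variant is easiest to see on a toy example; sizes and indices below are illustrative:

```python
import torch

# Toy illustration of the semi-transposed layout: a flattened column
# index c = t * nlon_out + p has its longitude p mirrored to
# nlon_out - 1 - p before the sparse index tensor is rebuilt.
nlon_out = 4
col = torch.tensor([0, 1, 5, 7])       # flattened (lat, lon) indices
t = col // nlon_out                    # latitude ring
p = nlon_out - 1 - col % nlon_out      # flipped longitude
print(torch.stack([t, p]))             # tensor([[0, 0, 1, 1], [3, 2, 2, 0]])
```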
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)
......
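The reshape at the top of forward factors the channel dimension into groups; a toy illustration, assuming C == groups * groupsize:

```python
import torch

# Sketch of the grouped reshape: the channel dim is factored into
# (groups, groupsize) so each group is contracted with its own slice
# of the weights. Toy sizes, not the module's configuration.
B, C, H, W = 2, 6, 8, 16
groups, groupsize = 3, 2
x = torch.randn(B, C, H, W)
x = x.reshape(B, groups, groupsize, H, W)
print(x.shape)  # torch.Size([2, 3, 2, 8, 16])
```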