Commit 61dd6cf1 authored by Boris Bonev's avatar Boris Bonev
Browse files

removing get_psi which got added from merge

parent ba7a4996
...@@ -61,7 +61,7 @@ def _normalize_convolution_tensor_s2( ...@@ -61,7 +61,7 @@ def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9 psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
): ):
"""Normalizes convolution tensor values based on specified normalization mode. """Normalizes convolution tensor values based on specified normalization mode.
This function applies different normalization strategies to the convolution tensor This function applies different normalization strategies to the convolution tensor
values based on the basis_norm_mode parameter. It can normalize individual basis values based on the basis_norm_mode parameter. It can normalize individual basis
functions, compute mean normalization across all basis functions, or use support functions, compute mean normalization across all basis functions, or use support
...@@ -143,7 +143,6 @@ def _normalize_convolution_tensor_s2( ...@@ -143,7 +143,6 @@ def _normalize_convolution_tensor_s2(
# compute the support # compute the support
support[ik, ilat] = torch.sum(q[iidx]) support[ik, ilat] = torch.sum(q[iidx])
# loop over values and renormalize # loop over values and renormalize
for ik in range(kernel_size): for ik in range(kernel_size):
for ilat in range(nlat_out): for ilat in range(nlat_out):
...@@ -166,7 +165,6 @@ def _normalize_convolution_tensor_s2( ...@@ -166,7 +165,6 @@ def _normalize_convolution_tensor_s2(
if merge_quadrature: if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx] psi_vals[iidx] = psi_vals[iidx] * q[iidx]
if transpose_normalization and merge_quadrature: if transpose_normalization and merge_quadrature:
psi_vals = psi_vals / correction_factor psi_vals = psi_vals / correction_factor
...@@ -178,13 +176,13 @@ def _precompute_convolution_tensor_s2( ...@@ -178,13 +176,13 @@ def _precompute_convolution_tensor_s2(
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
filter_basis: FilterBasis, filter_basis: FilterBasis,
grid_in: Optional[str]="equiangular", grid_in: Optional[str] = "equiangular",
grid_out: Optional[str]="equiangular", grid_out: Optional[str] = "equiangular",
theta_cutoff: Optional[float]=0.01 * math.pi, theta_cutoff: Optional[float] = 0.01 * math.pi,
theta_eps: Optional[float]=1e-3, theta_eps: Optional[float] = 1e-3,
transpose_normalization: Optional[bool]=False, transpose_normalization: Optional[bool] = False,
basis_norm_mode: Optional[str]="mean", basis_norm_mode: Optional[str] = "mean",
merge_quadrature: Optional[bool]=False, merge_quadrature: Optional[bool] = False,
): ):
""" """
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
...@@ -515,18 +513,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -515,18 +513,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
""" """
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self):
"""
Get the convolution tensor
Returns
-------
psi: torch.Tensor
Convolution tensor
"""
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.is_cuda and _cuda_extension_available: if x.is_cuda and _cuda_extension_available:
...@@ -582,7 +568,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -582,7 +568,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
Whether to use bias Whether to use bias
theta_cutoff: Optional[float] theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions Theta cutoff for the filter basis functions
Returns Returns
-------- --------
out: torch.Tensor out: torch.Tensor
...@@ -663,23 +649,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -663,23 +649,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
def psi_idx(self): def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False):
if semi_transposed:
# we do a semi-transposition to faciliate the computation
tout = self.psi_idx[2] // self.nlon_out
pout = self.psi_idx[2] % self.nlon_out
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# extract shape # extract shape
B, C, H, W = x.shape B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W) x = x.reshape(B, self.groups, self.groupsize, H, W)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment