Commit 856a0f18 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

reverting tests to 1-norm

parent 39298ffe
...@@ -41,76 +41,6 @@ from torch_harmonics import * ...@@ -41,76 +41,6 @@ from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
# def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
# """
# helper routine to compute the values of the isotropic kernel densely
# """
# kernel_size = (nr // 2) + nr % 2
# ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
# dr = 2 * r_cutoff / (nr + 1)
# # compute the support
# if nr % 2 == 1:
# ir = ikernel * dr
# else:
# ir = (ikernel + 0.5) * dr
# vals = torch.where(
# ((r - ir).abs() <= dr) & (r <= r_cutoff),
# (1 - (r - ir).abs() / dr),
# 0,
# )
# return vals
# def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
# """
# helper routine to compute the values of the anisotropic kernel densely
# """
# kernel_size = (nr // 2) * nphi + nr % 2
# ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
# dr = 2 * r_cutoff / (nr + 1)
# dphi = 2.0 * math.pi / nphi
# # disambiguate even and uneven cases and compute the support
# if nr % 2 == 1:
# ir = ((ikernel - 1) // nphi + 1) * dr
# iphi = ((ikernel - 1) % nphi) * dphi
# else:
# ir = (ikernel // nphi + 0.5) * dr
# iphi = (ikernel % nphi) * dphi
# # compute the value of the filter
# if nr % 2 == 1:
# # find the indices where the rotated position falls into the support of the kernel
# cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
# cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
# phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
# vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
# else:
# # find the indices where the rotated position falls into the support of the kernel
# cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
# cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
# phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
# vals = r_vals * phi_vals
# # in the even case, the inner casis functions overlap into areas with a negative areas
# rn = -r
# phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
# cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
# cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
# rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
# phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
# vals += rn_vals * phin_vals
# return vals
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9): def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
""" """
Discretely normalizes the convolution tensor. Discretely normalizes the convolution tensor.
...@@ -123,18 +53,18 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati ...@@ -123,18 +53,18 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
if transpose_normalization: if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation # look at the normalization code in the actual implementation
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi) psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
else: else:
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi) psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
elif basis_norm_mode == "mean": elif basis_norm_mode == "mean":
if transpose_normalization: if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation # look at the normalization code in the actual implementation
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi) psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
psi_norm = psi_norm.mean(dim=3, keepdim=True) psi_norm = psi_norm.mean(dim=3, keepdim=True)
else: else:
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi) psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
psi_norm = psi_norm.mean(dim=1, keepdim=True) psi_norm = psi_norm.mean(dim=1, keepdim=True)
elif basis_norm_mode == "none": elif basis_norm_mode == "none":
psi_norm = 1.0 psi_norm = 1.0
...@@ -186,9 +116,9 @@ def _precompute_convolution_tensor_dense( ...@@ -186,9 +116,9 @@ def _precompute_convolution_tensor_dense(
# compute quadrature weights that will be merged into the Psi tensor # compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization: if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in) out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment