Commit 39298ffe authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

adapting quadrature to sum to 1 in DISCO conv and revertign to 1-norm

parent 60aea808
...@@ -96,8 +96,9 @@ def _normalize_convolution_tensor_s2( ...@@ -96,8 +96,9 @@ def _normalize_convolution_tensor_s2(
# find indices corresponding to the given output latitude and kernel basis function # find indices corresponding to the given output latitude and kernel basis function
iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat)) iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
# compute the 2-norm, accounting for the fact that it is 4-pi normalized # compute the 1-norm
vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]) / 4 / torch.pi) # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
# loop over values and renormalize # loop over values and renormalize
for ik in range(kernel_size): for ik in range(kernel_size):
...@@ -114,10 +115,10 @@ def _normalize_convolution_tensor_s2( ...@@ -114,10 +115,10 @@ def _normalize_convolution_tensor_s2(
else: else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.") raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
psi_vals[iidx] = psi_vals[iidx] / (val + eps)
if merge_quadrature: if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (val + eps) psi_vals[iidx] = psi_vals[iidx] * q[iidx]
else:
psi_vals[iidx] = psi_vals[iidx] / (val + eps)
if transpose_normalization and merge_quadrature: if transpose_normalization and merge_quadrature:
...@@ -171,11 +172,12 @@ def _precompute_convolution_tensor_s2( ...@@ -171,11 +172,12 @@ def _precompute_convolution_tensor_s2(
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
# compute quadrature weights that will be merged into the Psi tensor # compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
out_idx = [] out_idx = []
out_vals = [] out_vals = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment