Commit 15d0750c authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

switched psi tensor computation to double precision and implemented a fudge...

switched psi tensor computation to double precision and implemented a fudge factor for theta_cutoff to avoid aliasing issues with the grid width
parent 55bbcb25
......@@ -84,11 +84,11 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
def _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps=1e-3,
transpose_normalization=False,
basis_norm_mode="none",
merge_quadrature=False,
......@@ -106,21 +106,25 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices
lats_out = torch.from_numpy(lats_out)
# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
# effective theta cutoff is multiplied by a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
# array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
for t in range(nlat_out):
for p in range(nlon_out):
......@@ -147,13 +151,14 @@ def _precompute_convolution_tensor_dense(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals
# take care of normalization
# take care of normalization and cast to float
out = _normalize_convolution_tensor_dense(
out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
)
out = out.to(dtype=torch.float32)
return out
......@@ -239,7 +244,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
psi_dense = _precompute_convolution_tensor_dense(
out_shape,
in_shape,
kernel_shape,
filter_basis,
grid_in=grid_out,
grid_out=grid_in,
......@@ -256,7 +260,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
psi_dense = _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
filter_basis,
grid_in=grid_in,
grid_out=grid_out,
......
......@@ -134,6 +134,7 @@ def _precompute_convolution_tensor_s2(
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
......@@ -164,20 +165,23 @@ def _precompute_convolution_tensor_s2(
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
lats_out = torch.from_numpy(lats_out)
# compute the phi differences
# It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
# effective theta cutoff is multiplied by a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
out_idx = []
out_vals = []
......@@ -207,7 +211,7 @@ def _precompute_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
......@@ -217,8 +221,8 @@ def _precompute_convolution_tensor_s2(
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)
out_vals = _normalize_convolution_tensor_s2(
out_idx,
......@@ -232,6 +236,9 @@ def _precompute_convolution_tensor_s2(
merge_quadrature=merge_quadrature,
)
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals
......
......@@ -75,6 +75,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
......@@ -103,21 +104,25 @@ def _precompute_distributed_convolution_tensor_s2(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
lats_out = torch.from_numpy(lats_out)
# compute the phi differences
# It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
# effective theta cutoff is multiplied by a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
out_idx = []
out_vals = []
......@@ -147,7 +152,7 @@ def _precompute_distributed_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
......@@ -157,8 +162,8 @@ def _precompute_distributed_convolution_tensor_s2(
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)
out_vals = _normalize_convolution_tensor_s2(
out_idx,
......@@ -189,6 +194,9 @@ def _precompute_distributed_convolution_tensor_s2(
# for the indices we need to recompute them to refer to local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment