Commit 15d0750c authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

switched psi tensor computation to double precision and implemented a fudge...

switched psi tensor computation to double precision and implemented a fudge factor for theta_cutoff to avoid aliasing issues with the grid width
parent 55bbcb25
...@@ -84,11 +84,11 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati ...@@ -84,11 +84,11 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
def _precompute_convolution_tensor_dense( def _precompute_convolution_tensor_dense(
in_shape, in_shape,
out_shape, out_shape,
kernel_shape,
filter_basis, filter_basis,
grid_in="equiangular", grid_in="equiangular",
grid_out="equiangular", grid_out="equiangular",
theta_cutoff=0.01 * math.pi, theta_cutoff=0.01 * math.pi,
theta_eps=1e-3,
transpose_normalization=False, transpose_normalization=False,
basis_norm_mode="none", basis_norm_mode="none",
merge_quadrature=False, merge_quadrature=False,
...@@ -106,21 +106,25 @@ def _precompute_convolution_tensor_dense( ...@@ -106,21 +106,25 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float() lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices lats_out = torch.from_numpy(lats_out)
# compute the phi differences. We need to make the linspace exclusive to not double the last point # compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1] lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
# compute quadrature weights that will be merged into the Psi tensor # compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in) # array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
for t in range(nlat_out): for t in range(nlat_out):
for p in range(nlon_out): for p in range(nlon_out):
...@@ -147,13 +151,14 @@ def _precompute_convolution_tensor_dense( ...@@ -147,13 +151,14 @@ def _precompute_convolution_tensor_dense(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi) phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff) iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals
# take care of normalization # take care of normalization and cast to float
out = _normalize_convolution_tensor_dense( out = _normalize_convolution_tensor_dense(
out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
) )
out = out.to(dtype=torch.float32)
return out return out
...@@ -239,7 +244,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -239,7 +244,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
psi_dense = _precompute_convolution_tensor_dense( psi_dense = _precompute_convolution_tensor_dense(
out_shape, out_shape,
in_shape, in_shape,
kernel_shape,
filter_basis, filter_basis,
grid_in=grid_out, grid_in=grid_out,
grid_out=grid_in, grid_out=grid_in,
...@@ -256,7 +260,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -256,7 +260,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
psi_dense = _precompute_convolution_tensor_dense( psi_dense = _precompute_convolution_tensor_dense(
in_shape, in_shape,
out_shape, out_shape,
kernel_shape,
filter_basis, filter_basis,
grid_in=grid_in, grid_in=grid_in,
grid_out=grid_out, grid_out=grid_out,
......
...@@ -134,6 +134,7 @@ def _precompute_convolution_tensor_s2( ...@@ -134,6 +134,7 @@ def _precompute_convolution_tensor_s2(
grid_in="equiangular", grid_in="equiangular",
grid_out="equiangular", grid_out="equiangular",
theta_cutoff=0.01 * math.pi, theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False, transpose_normalization=False,
basis_norm_mode="mean", basis_norm_mode="mean",
merge_quadrature=False, merge_quadrature=False,
...@@ -164,20 +165,23 @@ def _precompute_convolution_tensor_s2( ...@@ -164,20 +165,23 @@ def _precompute_convolution_tensor_s2(
# precompute input and output grids # precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float() lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float() lats_out = torch.from_numpy(lats_out)
# compute the phi differences # compute the phi differences
# It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
# compute quadrature weights and merge them into the convolution tensor. # compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere. # These quadrature integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
out_idx = [] out_idx = []
out_vals = [] out_vals = []
...@@ -207,7 +211,7 @@ def _precompute_convolution_tensor_s2( ...@@ -207,7 +211,7 @@ def _precompute_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi) phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff) iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in) # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0) idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
...@@ -217,8 +221,8 @@ def _precompute_convolution_tensor_s2( ...@@ -217,8 +221,8 @@ def _precompute_convolution_tensor_s2(
out_vals.append(vals) out_vals.append(vals)
# concatenate the indices and values # concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous() out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous() out_vals = torch.cat(out_vals, dim=-1)
out_vals = _normalize_convolution_tensor_s2( out_vals = _normalize_convolution_tensor_s2(
out_idx, out_idx,
...@@ -232,6 +236,9 @@ def _precompute_convolution_tensor_s2( ...@@ -232,6 +236,9 @@ def _precompute_convolution_tensor_s2(
merge_quadrature=merge_quadrature, merge_quadrature=merge_quadrature,
) )
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals return out_idx, out_vals
......
...@@ -75,6 +75,7 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -75,6 +75,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_in="equiangular", grid_in="equiangular",
grid_out="equiangular", grid_out="equiangular",
theta_cutoff=0.01 * math.pi, theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False, transpose_normalization=False,
basis_norm_mode="mean", basis_norm_mode="mean",
merge_quadrature=False, merge_quadrature=False,
...@@ -103,21 +104,25 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -103,21 +104,25 @@ def _precompute_distributed_convolution_tensor_s2(
nlat_in, nlon_in = in_shape nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float() lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float() lats_out = torch.from_numpy(lats_out)
# compute the phi differences # compute the phi differences
# It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
# compute quadrature weights and merge them into the convolution tensor. # compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere. # These quadrature integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0 quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
out_idx = [] out_idx = []
out_vals = [] out_vals = []
...@@ -147,7 +152,7 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -147,7 +152,7 @@ def _precompute_distributed_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi) phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff) iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in) # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0) idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
...@@ -157,8 +162,8 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -157,8 +162,8 @@ def _precompute_distributed_convolution_tensor_s2(
out_vals.append(vals) out_vals.append(vals)
# concatenate the indices and values # concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous() out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous() out_vals = torch.cat(out_vals, dim=-1)
out_vals = _normalize_convolution_tensor_s2( out_vals = _normalize_convolution_tensor_s2(
out_idx, out_idx,
...@@ -189,6 +194,9 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -189,6 +194,9 @@ def _precompute_distributed_convolution_tensor_s2(
# for the indices we need to recompute them to refer to local indices of the input tensor # for the indices we need to recompute them to refer to local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0) out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals return out_idx, out_vals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment