Commit 60aea808 authored by Boris Bonev, committed by Boris Bonev

Cleaning up normalization of DISCO convolutions

parent e1e079b9
......@@ -430,21 +430,21 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none",
)
models[f"lsno_sc2_layers4_e32"] = partial(
LSNO,
spectral_transform="sht",
img_size=(nlat, nlon),
grid=grid,
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
)
# models[f"lsno_sc2_layers4_e32"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid=grid,
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# normalization_layer="none",
# )
# iterate over models and train each model
root_path = os.path.dirname(__file__)
......
This diff is collapsed.
......@@ -40,8 +40,8 @@ def plot_sphere(data,
title=None,
colorbar=False,
coastlines=False,
central_latitude=20,
central_longitude=20,
central_latitude=0,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
......@@ -74,14 +74,15 @@ def plot_sphere(data,
return im
def plot_data(data,
fig=None,
projection=None,
cmap="RdBu",
title=None,
colorbar=False,
lon=None,
lat=None,
**kwargs):
fig=None,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
if fig is None:
fig = plt.figure()
......@@ -93,16 +94,19 @@ def plot_data(data,
lat = np.linspace(np.pi/2., -np.pi/2., nlat)
Lon, Lat = np.meshgrid(lon, lat)
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 1, 1, projection=projection)
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, **kwargs)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi
Lat = Lat*180/np.pi
# plot the data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
if colorbar:
plt.colorbar(im)
plt.title(title, y=1.05)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
return im
\ No newline at end of file
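# Hedged usage sketch for the updated plot_data signature above (illustrative values;
# assumes cartopy and matplotlib are installed):
import numpy as np
data = np.random.randn(91, 180)  # (nlat, nlon) scalar field on an equiangular grid
im = plot_data(data, colorbar=True, coastlines=True, central_longitude=0, title="random field")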
......@@ -41,92 +41,110 @@ from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
"""
helper routine to compute the values of the isotropic kernel densely
"""
kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
# compute the support
if nr % 2 == 1:
ir = ikernel * dr
else:
ir = (ikernel + 0.5) * dr
vals = torch.where(
((r - ir).abs() <= dr) & (r <= r_cutoff),
(1 - (r - ir).abs() / dr),
0,
)
return vals
def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
"""
helper routine to compute the values of the anisotropic kernel densely
"""
kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi
# disambiguate even and odd cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi
# compute the value of the filter
if nr % 2 == 1:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
else:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = r_vals * phi_vals
# in the even case, the inner basis functions overlap into regions with negative radii
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
vals += rn_vals * phin_vals
return vals
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
# def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
# """
# helper routine to compute the values of the isotropic kernel densely
# """
# kernel_size = (nr // 2) + nr % 2
# ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
# dr = 2 * r_cutoff / (nr + 1)
# # compute the support
# if nr % 2 == 1:
# ir = ikernel * dr
# else:
# ir = (ikernel + 0.5) * dr
# vals = torch.where(
# ((r - ir).abs() <= dr) & (r <= r_cutoff),
# (1 - (r - ir).abs() / dr),
# 0,
# )
# return vals
# def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
# """
# helper routine to compute the values of the anisotropic kernel densely
# """
# kernel_size = (nr // 2) * nphi + nr % 2
# ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
# dr = 2 * r_cutoff / (nr + 1)
# dphi = 2.0 * math.pi / nphi
# # disambiguate even and odd cases and compute the support
# if nr % 2 == 1:
# ir = ((ikernel - 1) // nphi + 1) * dr
# iphi = ((ikernel - 1) % nphi) * dphi
# else:
# ir = (ikernel // nphi + 0.5) * dr
# iphi = (ikernel % nphi) * dphi
# # compute the value of the filter
# if nr % 2 == 1:
# # find the indices where the rotated position falls into the support of the kernel
# cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
# cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
# phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
# vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
# else:
# # find the indices where the rotated position falls into the support of the kernel
# cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
# cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
# phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
# vals = r_vals * phi_vals
# # in the even case, the inner basis functions overlap into regions with negative radii
# rn = -r
# phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
# cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
# cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
# rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
# phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
# vals += rn_vals * phin_vals
# return vals
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
scale_factor = float(nlon_in // nlon_out)
correction_factor = nlon_out / nlon_in
if basis_norm_mode == "individual":
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi)
else:
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi)
elif basis_norm_mode == "mean":
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi)
psi_norm = psi_norm.mean(dim=3, keepdim=True)
else:
psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi)
psi_norm = psi_norm.mean(dim=1, keepdim=True)
elif basis_norm_mode == "none":
psi_norm = 1.0
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1], dim=(1, 4), keepdim=True) / scale_factor
if merge_quadrature:
psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi
psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi / correction_factor
else:
psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
if merge_quadrature:
psi = quad_weights.reshape(1, 1, 1, -1, 1) * psi
......@@ -137,11 +155,12 @@ def _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
quad_weights,
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="none",
merge_quadrature=False,
):
"""
......@@ -151,29 +170,26 @@ def _precompute_convolution_tensor_dense(
assert len(in_shape) == 2
assert len(out_shape) == 2
quad_weights = quad_weights.reshape(-1, 1)
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = math.ceil(kernel_shape[0] / 2)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# compute the phi differences. We make the linspace exclusive so that the endpoint 2*pi does not duplicate 0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
for t in range(nlat_out):
......@@ -187,7 +203,7 @@ def _precompute_convolution_tensor_dense(
# compute cartesian coordinates of the rotated position
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
y = torch.sin(beta) * torch.sin(gamma) * torch.ones_like(alpha)
# normalize instead of clipping to ensure correct range
norm = torch.sqrt(x * x + y * y + z * z)
......@@ -197,13 +213,17 @@ def _precompute_convolution_tensor_dense(
# compute spherical coordinates
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
phi = torch.arctan2(y, x)
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
out[:, t, p, :, :] = kernel_handle(theta, phi)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals
# take care of normalization
out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)
out = _normalize_convolution_tensor_dense(
out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
)
return out
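# Hedged usage sketch of the refactored dense helper above, mirroring the call in the tests below.
# The shapes, the cutoff, and the import path of get_filter_basis are assumptions for illustration:
from torch_harmonics.filter_basis import get_filter_basis  # assumed module path

filter_basis = get_filter_basis(kernel_shape=[3], basis_type="piecewise linear")
theta_cutoff = (3 + 1) * torch.pi / float(16 - 1)  # heuristic used in the tests below
psi_dense = _precompute_convolution_tensor_dense(
    (16, 32),  # in_shape
    (8, 16),  # out_shape
    [3],  # kernel_shape
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=theta_cutoff,
    basis_norm_mode="mean",
    merge_quadrature=True,
)  # -> dense tensor of shape (kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)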
......@@ -217,30 +237,32 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
self.device = torch.device("cpu")
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), [3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), [3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 1e-4],
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
]
)
def test_disco_convolution(
......@@ -251,6 +273,8 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
in_shape,
out_shape,
kernel_shape,
basis_type,
basis_norm_mode,
grid_in,
grid_out,
transpose,
......@@ -259,19 +283,38 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, groups=1, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=theta_cutoff).to(
self.device
)
_, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in
conv = Conv(
in_channels,
out_channels,
in_shape,
out_shape,
kernel_shape,
basis_type=basis_type,
basis_norm_mode=basis_norm_mode,
groups=1,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
).to(self.device)
filter_basis = conv.filter_basis
if transpose:
psi_dense = _precompute_convolution_tensor_dense(
out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
out_shape,
in_shape,
kernel_shape,
filter_basis,
grid_in=grid_out,
grid_out=grid_in,
theta_cutoff=theta_cutoff,
transpose_normalization=transpose,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()
......@@ -279,7 +322,16 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
else:
psi_dense = _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
in_shape,
out_shape,
kernel_shape,
filter_basis,
grid_in=grid_in,
grid_out=grid_out,
theta_cutoff=theta_cutoff,
transpose_normalization=transpose,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
......
......@@ -183,21 +183,23 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@parameterized.expand(
[
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
]
)
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
def test_distributed_disco_conv(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
......@@ -206,6 +208,8 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_channels=C,
in_shape=(nlat_in, nlon_in),
out_shape=(nlat_out, nlon_out),
basis_type=basis_type,
basis_norm_mode=basis_norm_mode,
kernel_shape=kernel_shape,
groups=groups,
grid_in=grid_in,
......
......@@ -57,70 +57,71 @@ except ImportError as err:
def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="sum", merge_quadrature=False, eps=1e-9
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9
):
"""
Discretely normalizes the convolution tensor. Supports different normalization modes
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
- "none": No normalization is applied.
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
"""
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
# reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)
# get indices for addressing kernels, input and output latitudes
ikernel = idx[0]
if transpose_normalization:
# pre-compute the quadrature weights
q = quad_weights[idx[1]].reshape(-1)
# loop through dimensions which require normalization
for ik in range(kernel_size):
for ilat in range(nlat_in):
# get relevant entries depending on the normalization mode
if basis_norm_mode == "individual":
iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
elif basis_norm_mode == "sum":
# this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
iidx = torch.argwhere(idx[2] == ilat)
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if merge_quadrature:
# the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
else:
psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
ilat_out = idx[2]
ilat_in = idx[1]
# here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
nlat_out = in_shape[0]
correction_factor = out_shape[1] / in_shape[1]
else:
# pre-compute the quadrature weights
q = quad_weights[idx[2]].reshape(-1)
# loop through dimensions which require normalization
for ik in range(kernel_size):
for ilat in range(nlat_out):
# get relevant entries depending on the normalization mode
if basis_norm_mode == "individual":
iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
# normalize
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
elif basis_norm_mode == "sum":
# this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
iidx = torch.argwhere(idx[1] == ilat)
# normalize
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
else:
psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
ilat_out = idx[1]
ilat_in = idx[2]
nlat_out = out_shape[0]
# get the quadrature weights
q = quad_weights[ilat_in].reshape(-1)
# buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out)
# loop through dimensions to compute the norms
for ik in range(kernel_size):
for ilat in range(nlat_out):
# find indices corresponding to the given output latitude and kernel basis function
iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
# compute the 2-norm, accounting for the fact that it is 4-pi normalized
vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]) / 4 / torch.pi)
# loop over values and renormalize
for ik in range(kernel_size):
for ilat in range(nlat_out):
iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
if basis_norm_mode == "individual":
val = vnorm[ik, ilat]
elif basis_norm_mode == "mean":
val = vnorm[ik, :].mean()
elif basis_norm_mode == "none":
val = 1.0
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (val + eps)
else:
psi_vals[iidx] = psi_vals[iidx] / (val + eps)
if transpose_normalization and merge_quadrature:
psi_vals = psi_vals / correction_factor
return psi_vals
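# Hedged toy illustration of the three basis_norm_mode options documented above, on a dense
# tensor; names and shapes here are illustrative and do not match the sparse layout used above:
import torch
toy_psi = torch.rand(3, 4, 8)  # (kernel_size, nlat_out, n_in)
toy_q = torch.rand(8)  # quadrature weights on the input grid
# "individual": 4-pi-normalized 2-norm per basis function and output latitude
toy_vnorm = torch.sqrt(torch.sum(toy_psi.pow(2) * toy_q, dim=-1, keepdim=True) / 4 / torch.pi)
psi_individual = toy_psi / toy_vnorm
# "mean": each basis function is normalized by its norm averaged over the output latitudes
psi_mean = toy_psi / toy_vnorm.mean(dim=1, keepdim=True)
# "none": the filter tensor is left unnormalized
psi_none = toy_psi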
......@@ -133,7 +134,7 @@ def _precompute_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="sum",
basis_norm_mode="none",
merge_quadrature=False,
):
"""
......@@ -187,11 +188,11 @@ def _precompute_convolution_tensor_s2(
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and is therefore applied with a negative sign
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
# normalization is emportant to avoid NaNs when arccos and atan are applied
# normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
......@@ -200,7 +201,8 @@ def _precompute_convolution_tensor_s2(
# compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
phi = torch.arctan2(y, x)
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
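# Standalone sketch of the coordinate computation above with illustrative YZY Euler angles
# (not part of the library API):
import torch
alpha = torch.linspace(0, 2 * torch.pi, 9)[:-1]
beta = torch.tensor(0.3)
gamma = torch.tensor(0.7)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma) * torch.ones_like(alpha)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
norm = torch.sqrt(x * x + y * y + z * z)  # normalize to keep arccos/arctan2 well-defined
theta = torch.arccos(z / norm)
phi = torch.arctan2(y / norm, x / norm)
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)  # wrap phi into [0, 2*pi)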
......@@ -293,7 +295,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
basis_norm_mode: Optional[str] = "none",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -305,6 +307,9 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
# make sure the p-shift works by checking that longitudes are divisible
assert self.nlon_in % self.nlon_out == 0
# heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
if theta_cutoff is None:
theta_cutoff = torch.pi / float(self.nlat_out - 1)
......@@ -396,7 +401,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
basis_norm_mode: Optional[str] = "none",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -408,6 +413,9 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
# make sure the p-shift works by checking that longitudes are divisible
assert self.nlon_out % self.nlon_in == 0
# bandlimit
if theta_cutoff is None:
theta_cutoff = torch.pi / float(self.nlat_in - 1)
......@@ -415,7 +423,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
# switch in_shape and out_shape since we want transpose conv
# switch in_shape and out_shape since we want the transpose convolution
idx, vals = _precompute_convolution_tensor_s2(
out_shape,
in_shape,
......
......@@ -76,7 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="sum",
basis_norm_mode="none",
merge_quadrature=False,
):
"""
......@@ -142,7 +142,8 @@ def _precompute_distributed_convolution_tensor_s2(
# compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
phi = torch.arctan2(y, x)
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
......@@ -207,7 +208,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
basis_norm_mode: Optional[str] = "none",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -347,7 +348,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
basis_norm_mode: Optional[str] = "none",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......
......@@ -41,12 +41,19 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
if basis_type == "piecewise linear":
return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "disk morlet":
return DiskMorletFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "morlet":
return MorletFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "zernike":
raise NotImplementedError()
else:
raise ValueError(f"Unknown basis_type {basis_type}")
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
"""Helper function to compute the distance on a circle"""
return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
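# quick sanity check (illustrative): points on either side of the 0 / 2*pi seam are close
import torch, math
_a = torch.tensor([0.1])
_b = torch.tensor([2 * math.pi - 0.1])
assert torch.allclose(_circle_dist(_a, _b), torch.tensor([0.2]), atol=1e-5)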
class FilterBasis(metaclass=abc.ABCMeta):
"""
Abstract base class for a filter basis
......@@ -64,6 +71,13 @@ class FilterBasis(metaclass=abc.ABCMeta):
def kernel_size(self):
raise NotImplementedError
# @abc.abstractmethod
# def compute_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
# """
# Computes the values of the filter basis
# """
# raise NotImplementedError
@abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
......@@ -136,49 +150,49 @@ class PiecewiseLinearFilterBasis(FilterBasis):
# disambiguate even and odd cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
iphi = ((ikernel - 1) % nphi) * dphi - math.pi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi
iphi = (ikernel % nphi) * dphi - math.pi
# find the indices where the rotated position falls into the support of the kernel
if nr % 2 == 1:
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi)
# find indices where conditions are met
iidx = torch.argwhere(cond_r & cond_phi)
# compute the distance to the collocation points
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
# compute the value of the basis functions
vals = 1 - dist_r / dr
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
(1 - dist_phi / dphi),
1.0,
)
else:
# in the even case, the inner casis functions overlap into areas with a negative areas
# in the even case, the inner basis functions overlap into regions with negative radii
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi)
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_phi = _circle_dist(phi, iphi).abs() <= dphi
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
cond_phin = _circle_dist(phin, iphi) <= dphi
# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_phin = _circle_dist(phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
# compute the value of the basis functions
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr)
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr)
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi)
vals += valsn
return iidx, vals
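# Hedged evaluation sketch for the anisotropic basis above on a small (theta, phi) grid;
# the grid sizes and cutoff are illustrative:
import torch, math
basis = PiecewiseLinearFilterBasis(kernel_shape=[3, 4])
theta = torch.linspace(0.0, 0.5, 10).reshape(-1, 1).expand(10, 16)
phi = torch.linspace(0.0, 2 * math.pi, 17)[:-1].reshape(1, -1).expand(10, 16)
iidx, vals = basis.compute_support_vals(theta, phi, r_cutoff=0.5)
# iidx[:, 0] indexes the basis function, iidx[:, 1:] the grid point it supports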
......@@ -190,9 +204,10 @@ class PiecewiseLinearFilterBasis(FilterBasis):
else:
return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)
class DiskMorletFilterBasis(FilterBasis):
class MorletFilterBasis(FilterBasis):
"""
Morlet-like Filter basis. A Gaussian is multiplied with a Fourier basis in x and y.
Morlet-style filter basis on the disk. A Gaussian window is multiplied with a Fourier basis in the x and y directions.
"""
def __init__(
......@@ -209,12 +224,12 @@ class DiskMorletFilterBasis(FilterBasis):
@property
def kernel_size(self):
return self.kernel_shape[0]*self.kernel_shape[1]
return self.kernel_shape[0] * self.kernel_shape[1]
def _gaussian_envelope(self, r: torch.Tensor, width: float =1.0):
return 1 / (2 * math.pi * width**2 ) * torch.exp(- 0.5 * r**2 / (width**2))
def _gaussian_window(self, r: torch.Tensor, width: float = 1.0):
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
......@@ -227,26 +242,18 @@ class DiskMorletFilterBasis(FilterBasis):
# get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
# # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
# width = 0.01
width = 0.25
# width = 1.0
# envelope = self._gaussian_envelope(r, width=0.25 * r_cutoff)
# get x and y
x = r * torch.sin(phi) / r_cutoff
y = r * torch.cos(phi) / r_cutoff
harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel/2) * math.pi * x / width), torch.cos(torch.ceil(nkernel/2) * math.pi * x / width))
harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel/2) * math.pi * y / width), torch.cos(torch.ceil(mkernel/2) * math.pi * y / width))
harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel / 2) * math.pi * x / width), torch.cos(torch.ceil(nkernel / 2) * math.pi * x / width))
harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel / 2) * math.pi * y / width), torch.cos(torch.ceil(mkernel / 2) * math.pi * y / width))
# disk area
# disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff))
disk_area = 1
disk_area = 1.0
# compute the Gaussian window. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals = self._gaussian_envelope(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
# vals = torch.ones_like(vals)
vals = self._gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
return iidx, vals
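# Hedged evaluation sketch for the Morlet basis (shapes and cutoff are illustrative; the
# Gaussian width defaults to 0.25 as above):
import torch, math
morlet = MorletFilterBasis(kernel_shape=[2, 2])  # kernel_size = 2 * 2 = 4
theta = torch.linspace(0.0, 0.5, 8).reshape(-1, 1).expand(8, 12)
phi = torch.linspace(0.0, 2 * math.pi, 13)[:-1].reshape(1, -1).expand(8, 12)
iidx, vals = morlet.compute_support_vals(theta, phi, r_cutoff=0.5)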
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -31,6 +31,7 @@
import numpy as np
def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
if (grid != "equidistant") and periodic:
......@@ -50,27 +51,31 @@ def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
return xlg, wlg
def _precompute_latitudes(nlat, grid="equiangular"):
r"""
Convenience routine to precompute latitudes
"""
# compute coordinates
# compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
# to perform the quadrature and account for the jacobian of the sphere, the quadrature rule
# is formulated in the cosine theta domain, which is designed to integrate functions of cos theta
lats = np.flip(np.arccos(xlg)).copy()
wlg = np.flip(wlg).copy()
return lats, wlg
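# sanity check (illustrative): latitudes run from 0 to pi and the weights integrate 1 to 2
import numpy as np
lats, w = _precompute_latitudes(16, grid="equiangular")
assert np.isclose(np.sum(w), 2.0)  # integral of 1 over [-1, 1] in the cos(theta) domain
assert np.isclose(lats[0], 0.0) and np.isclose(lats[-1], np.pi)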
def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
r"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
"""
xlg = np.linspace(a, b, n)
wlg = (b - a) / (n - 1) * np.ones(n)
xlg = np.linspace(a, b, n, endpoint=periodic)
wlg = (b - a) / (n - periodic * 1) * np.ones(n)
if not periodic:
wlg[0] *= 0.5
......@@ -78,6 +83,7 @@ def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
return xlg, wlg
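# sanity check (illustrative, non-periodic case): with both endpoint weights halved as in
# the standard trapezoidal rule, constants are integrated exactly
import numpy as np
x, w = trapezoidal_weights(9, a=-1.0, b=1.0)
assert np.isclose(np.sum(w), 2.0)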
def legendre_gauss_weights(n, a=-1.0, b=1.0):
r"""
Helper routine which returns the Legendre-Gauss nodes and weights
......@@ -90,6 +96,7 @@ def legendre_gauss_weights(n, a=-1.0, b=1.0):
return xlg, wlg
def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
r"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
......@@ -102,33 +109,33 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
# Vandermonde Matrix
vdm = np.zeros((n, n))
# initialize Chebyshev nodes as first guess
for i in range(n):
tlg[i] = -np.cos(np.pi*i / (n-1))
for i in range(n):
tlg[i] = -np.cos(np.pi * i / (n - 1))
tmp = 2.0
for i in range(maxiter):
tmp = tlg
vdm[:,0] = 1.0
vdm[:,1] = tlg
vdm[:, 0] = 1.0
vdm[:, 1] = tlg
for k in range(2, n):
vdm[:, k] = ( (2*k-1) * tlg * vdm[:, k-1] - (k-1) * vdm[:, k-2] ) / k
tlg = tmp - ( tlg*vdm[:, n-1] - vdm[:, n-2] ) / ( n * vdm[:, n-1])
if (max(abs(tlg - tmp).flatten()) < tol ):
break
wlg = 2.0 / ( (n*(n-1))*(vdm[:, n-1]**2))
vdm[:, k] = ((2 * k - 1) * tlg * vdm[:, k - 1] - (k - 1) * vdm[:, k - 2]) / k
tlg = tmp - (tlg * vdm[:, n - 1] - vdm[:, n - 2]) / (n * vdm[:, n - 1])
if max(abs(tlg - tmp).flatten()) < tol:
break
wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n - 1] ** 2))
# rescale
tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5
wlg = wlg * (b - a) * 0.5
return tlg, wlg
......@@ -140,12 +147,12 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
[1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
"""
assert(n > 1)
assert n > 1
tcc = np.cos(np.linspace(np.pi, 0, n))
if n == 2:
wcc = np.array([1., 1.])
wcc = np.array([1.0, 1.0])
else:
n1 = n - 1
......@@ -153,13 +160,13 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
l = len(N)
m = n1 - l
v = np.concatenate([2 / N / (N-2), 1 / N[-1:], np.zeros(m)])
v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
v = 0 - v[:-1] - v[-1:0:-1]
g0 = -np.ones(n1)
g0[l] = g0[l] + n1
g0[m] = g0[m] + n1
g = g0 / (n1**2 - 1 + (n1%2))
g = g0 / (n1**2 - 1 + (n1 % 2))
wcc = np.fft.ifft(v + g).real
wcc = np.concatenate((wcc, wcc[:1]))
......@@ -169,6 +176,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
return tcc, wcc
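# sanity check (illustrative): the weights integrate low-degree polynomials exactly on [-1, 1]
import numpy as np
t, w = clenshaw_curtiss_weights(9)
assert np.isclose(np.sum(w), 2.0)  # integral of 1
assert np.isclose(np.sum(w * t**2), 2.0 / 3.0)  # integral of x^2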
def fejer2_weights(n, a=-1.0, b=1.0):
r"""
Computation of the Fejer quadrature nodes and weights.
......@@ -177,7 +185,7 @@ def fejer2_weights(n, a=-1.0, b=1.0):
[1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
"""
assert(n > 2)
assert n > 2
tcc = np.cos(np.linspace(np.pi, 0, n))
......@@ -186,7 +194,7 @@ def fejer2_weights(n, a=-1.0, b=1.0):
l = len(N)
m = n1 - l
v = np.concatenate([2 / N / (N-2), 1 / N[-1:], np.zeros(m)])
v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
v = 0 - v[:-1] - v[-1:0:-1]
wcc = np.fft.ifft(v).real
......@@ -196,4 +204,4 @@ def fejer2_weights(n, a=-1.0, b=1.0):
tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
wcc = wcc * (b - a) * 0.5
return tcc, wcc
\ No newline at end of file
return tcc, wcc
......@@ -68,19 +68,25 @@ class RealSHT(nn.Module):
# TODO: include assertions regarding the dimensions
# compute quadrature points
# compute quadrature points and lmax based on the exactness of the quadrature
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
# maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
# and therefore lmax = nlat - 1 (inclusive)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
# maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
# and therefore lmax = nlat - 2 (inclusive)
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
# in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degrees of nlat
# however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
# choose a lower lmax for now.
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
raise ValueError("Unknown quadrature mode")
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
......@@ -92,24 +98,24 @@ class RealSHT(nn.Module):
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
weights = torch.einsum("mlk,k->mlk", pct, weights)
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
if x.dim() < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
assert x.shape[-2] == self.nlat
assert x.shape[-1] == self.nlon
# apply real fft in the longitudinal direction
x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
......@@ -124,12 +130,13 @@ class RealSHT(nn.Module):
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
# contraction
xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights.to(x.dtype) )
xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights.to(x.dtype) )
xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype))
xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype))
x = torch.view_as_complex(xout)
return x
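# Hedged usage sketch of the forward transform (illustrative sizes):
import torch
sht = RealSHT(nlat=16, nlon=32, grid="equiangular")
signal = torch.randn(2, 16, 32)  # batch of scalar fields on the sphere
coeffs = sht(signal)  # complex coefficients of shape (2, lmax, mmax)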
class InverseRealSHT(nn.Module):
r"""
Defines a module for computing the inverse (real-valued) SHT.
......@@ -157,12 +164,12 @@ class InverseRealSHT(nn.Module):
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, _ = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
raise ValueError("Unknown quadrature mode")
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
......@@ -174,27 +181,27 @@ class InverseRealSHT(nn.Module):
pct = torch.from_numpy(pct)
# register buffer
self.register_buffer('pct', pct, persistent=False)
self.register_buffer("pct", pct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
if len(x.shape) < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {len(x.shape)} instead")
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
assert x.shape[-2] == self.lmax
assert x.shape[-1] == self.mmax
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
xs = torch.stack((rl, im), -1)
# apply the inverse (real) FFT
......@@ -238,13 +245,12 @@ class RealVectorSHT(nn.Module):
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
raise ValueError("Unknown quadrature mode")
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
......@@ -258,28 +264,28 @@ class RealVectorSHT(nn.Module):
# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
norm_factor = 1. / l / (l+1)
norm_factor[0] = 1.
weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
norm_factor = 1.0 / l / (l + 1)
norm_factor[0] = 1.0
weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor)
# since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
assert x.shape[-2] == self.nlat
assert x.shape[-1] == self.nlon
# apply real fft in the longitudinal direction
x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
......@@ -295,20 +301,24 @@ class RealVectorSHT(nn.Module):
# contraction - spheroidal component
# real component
xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1].to(x.dtype))
xout[..., 0, :, :, 0] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[0].to(x.dtype)) - torch.einsum(
"...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[1].to(x.dtype)
)
# imag component
xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0].to(x.dtype)) \
+ torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1].to(x.dtype))
xout[..., 0, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[0].to(x.dtype)) + torch.einsum(
"...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[1].to(x.dtype)
)
# contraction - toroidal component
# real component
xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0].to(x.dtype))
xout[..., 1, :, :, 0] = -torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[1].to(x.dtype)) - torch.einsum(
"...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[0].to(x.dtype)
)
# imag component
xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0].to(x.dtype))
xout[..., 1, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[1].to(x.dtype)) - torch.einsum(
"...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[0].to(x.dtype)
)
return torch.view_as_complex(xout)
......@@ -321,6 +331,7 @@ class InverseRealVectorSHT(nn.Module):
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
super().__init__()
......@@ -337,12 +348,12 @@ class InverseRealVectorSHT(nn.Module):
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, _ = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
raise ValueError("Unknown quadrature mode")
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
......@@ -354,40 +365,36 @@ class InverseRealVectorSHT(nn.Module):
dpct = torch.from_numpy(dpct)
# register weights
self.register_buffer('dpct', dpct, persistent=False)
self.register_buffer("dpct", dpct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
assert x.shape[-2] == self.lmax
assert x.shape[-1] == self.mmax
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
# contraction - spheroidal component
# real component
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
srl = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
# imag component
sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
+ torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))
sim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) + torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))
# contraction - toroidal component
# real component
trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
trl = -torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
# imag component
tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))
tim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))
# reassemble
s = torch.stack((srl, sim), -1)
......