Commit 60aea808 authored by Boris Bonev, committed by Boris Bonev

Cleaning up normalization of DISCO convolutions

parent e1e079b9
@@ -430,21 +430,21 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         normalization_layer="none",
     )
-    models[f"lsno_sc2_layers4_e32"] = partial(
-        LSNO,
-        spectral_transform="sht",
-        img_size=(nlat, nlon),
-        grid=grid,
-        num_layers=4,
-        scale_factor=2,
-        embed_dim=32,
-        operator_type="driscoll-healy",
-        activation_function="gelu",
-        big_skip=True,
-        pos_embed=False,
-        use_mlp=True,
-        normalization_layer="none",
-    )
+    # models[f"lsno_sc2_layers4_e32"] = partial(
+    #     LSNO,
+    #     spectral_transform="sht",
+    #     img_size=(nlat, nlon),
+    #     grid=grid,
+    #     num_layers=4,
+    #     scale_factor=2,
+    #     embed_dim=32,
+    #     operator_type="driscoll-healy",
+    #     activation_function="gelu",
+    #     big_skip=True,
+    #     pos_embed=False,
+    #     use_mlp=True,
+    #     normalization_layer="none",
+    # )
     # iterate over models and train each model
     root_path = os.path.dirname(__file__)
...
(diff for one file not shown: too large to display)
@@ -40,8 +40,8 @@ def plot_sphere(data,
                 title=None,
                 colorbar=False,
                 coastlines=False,
-                central_latitude=20,
-                central_longitude=20,
+                central_latitude=0,
+                central_longitude=0,
                 lon=None,
                 lat=None,
                 **kwargs):
@@ -74,14 +74,15 @@ def plot_sphere(data,
     return im
 
 def plot_data(data,
              fig=None,
-             projection=None,
              cmap="RdBu",
             title=None,
             colorbar=False,
+             coastlines=False,
+             central_longitude=0,
             lon=None,
             lat=None,
             **kwargs):
 
     if fig == None:
         fig = plt.figure()
@@ -93,16 +94,19 @@ def plot_data(data,
         lat = np.linspace(np.pi/2., -np.pi/2., nlat)
     Lon, Lat = np.meshgrid(lon, lat)
 
-    fig = plt.figure(figsize=(10, 5))
-    ax = fig.add_subplot(1, 1, 1, projection=projection)
-    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, **kwargs)
+    proj = ccrs.PlateCarree(central_longitude=central_longitude)
+    # proj = ccrs.Mollweide(central_longitude=central_longitude)
+
+    ax = fig.add_subplot(projection=proj)
+
+    Lon = Lon*180/np.pi
+    Lat = Lat*180/np.pi
+
+    # contour data over the map.
+    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
+
+    if coastlines:
+        ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
 
     if colorbar:
         plt.colorbar(im)
 
     plt.title(title, y=1.05)
+    ax.grid(True)
+    ax.set_xticklabels([])
+    ax.set_yticklabels([])
 
     return im
\ No newline at end of file
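For reference, a usage sketch of the reworked plot_data (the field and figure sizes are hypothetical; this assumes matplotlib and cartopy are installed and the plotting module is imported):

    import numpy as np
    import matplotlib.pyplot as plt

    # hypothetical (nlat, nlon) field on an equiangular grid
    data = np.random.randn(91, 180)
    fig = plt.figure(figsize=(10, 5))
    plot_data(data, fig=fig, title="example field", colorbar=True, coastlines=True, central_longitude=0)
    plt.show()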
@@ -41,92 +41,110 @@ from torch_harmonics import *
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
 
-def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
-    """
-    helper routine to compute the values of the isotropic kernel densely
-    """
-    kernel_size = (nr // 2) + nr % 2
-    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
-    dr = 2 * r_cutoff / (nr + 1)
-    # compute the support
-    if nr % 2 == 1:
-        ir = ikernel * dr
-    else:
-        ir = (ikernel + 0.5) * dr
-
-    vals = torch.where(
-        ((r - ir).abs() <= dr) & (r <= r_cutoff),
-        (1 - (r - ir).abs() / dr),
-        0,
-    )
-    return vals
-
-def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
-    """
-    helper routine to compute the values of the anisotropic kernel densely
-    """
-    kernel_size = (nr // 2) * nphi + nr % 2
-    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
-    dr = 2 * r_cutoff / (nr + 1)
-    dphi = 2.0 * math.pi / nphi
-    # disambiguate even and uneven cases and compute the support
-    if nr % 2 == 1:
-        ir = ((ikernel - 1) // nphi + 1) * dr
-        iphi = ((ikernel - 1) % nphi) * dphi
-    else:
-        ir = (ikernel // nphi + 0.5) * dr
-        iphi = (ikernel % nphi) * dphi
-
-    # compute the value of the filter
-    if nr % 2 == 1:
-        # find the indices where the rotated position falls into the support of the kernel
-        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
-        cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
-        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
-        phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
-        vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
-    else:
-        # find the indices where the rotated position falls into the support of the kernel
-        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
-        cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
-        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
-        phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
-        vals = r_vals * phi_vals
-
-        # in the even case, the inner basis functions overlap into areas with negative radius
-        rn = -r
-        phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
-        cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
-        cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
-        rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
-        phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
-        vals += rn_vals * phin_vals
-    return vals
[in the new revision, both helper routines above are retained verbatim but commented out line by line]
-def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
+def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
     """
     Discretely normalizes the convolution tensor.
     """
 
     kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
-    scale_factor = float(nlon_in // nlon_out)
+    correction_factor = nlon_out / nlon_in
+
+    if basis_norm_mode == "individual":
+        if transpose_normalization:
+            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
+            # look at the normalization code in the actual implementation
+            psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi)
+        else:
+            psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi)
+    elif basis_norm_mode == "mean":
+        if transpose_normalization:
+            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
+            # look at the normalization code in the actual implementation
+            psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs().pow(2), dim=(1, 4), keepdim=True) / 4 / math.pi)
+            psi_norm = psi_norm.mean(dim=3, keepdim=True)
+        else:
+            psi_norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs().pow(2), dim=(3, 4), keepdim=True) / 4 / math.pi)
+            psi_norm = psi_norm.mean(dim=1, keepdim=True)
+    elif basis_norm_mode == "none":
+        psi_norm = 1.0
+    else:
+        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
 
     if transpose_normalization:
-        # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
-        # look at the normalization code in the actual implementation
-        psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1], dim=(1, 4), keepdim=True) / scale_factor
         if merge_quadrature:
-            psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi
+            psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi / correction_factor
     else:
-        psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
         if merge_quadrature:
             psi = quad_weights.reshape(1, 1, 1, -1, 1) * psi
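For intuition, the "individual" mode above rescales each basis function so that its quadrature-weighted 2-norm over the sphere equals one under the 4*pi normalization. A minimal standalone sketch of that invariant, with made-up shapes rather than the library's actual tensors:

    import math
    import torch

    # hypothetical sizes: 2 basis functions, 4x8 output grid, 4x8 input grid
    psi = torch.rand(2, 4, 8, 4, 8)
    quad_weights = torch.rand(4) + 0.1  # stand-in latitude quadrature weights

    # per (basis function, output point) norm, integrated over the input grid
    norm = torch.sqrt(torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.pow(2), dim=(3, 4), keepdim=True) / (4 * math.pi))
    psi_normalized = psi / norm

    # after normalization every slice has unit 4-pi-normalized 2-norm
    check = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi_normalized.pow(2), dim=(3, 4)) / (4 * math.pi)
    assert torch.allclose(check, torch.ones_like(check), atol=1e-5)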
@@ -137,11 +155,12 @@ def _precompute_convolution_tensor_dense(
     in_shape,
     out_shape,
     kernel_shape,
-    quad_weights,
+    filter_basis,
     grid_in="equiangular",
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
     transpose_normalization=False,
+    basis_norm_mode="none",
     merge_quadrature=False,
 ):
     """
@@ -151,29 +170,26 @@ def _precompute_convolution_tensor_dense(
     assert len(in_shape) == 2
     assert len(out_shape) == 2
 
-    quad_weights = quad_weights.reshape(-1, 1)
-
-    if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
-        kernel_size = math.ceil(kernel_shape[0] / 2)
-    elif len(kernel_shape) == 2:
-        kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
-        kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
-    else:
-        raise ValueError("kernel_shape should be either one- or two-dimensional.")
+    kernel_size = filter_basis.kernel_size
 
     nlat_in, nlon_in = in_shape
     nlat_out, nlon_out = out_shape
 
-    lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
+    lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
     lats_in = torch.from_numpy(lats_in).float()
-    lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
+    lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
     lats_out = torch.from_numpy(lats_out).float()  # array for accumulating non-zero indices
 
     # compute the phi differences. We need to make the linspace exclusive to not double the last point
     lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
     lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
 
+    # compute quadrature weights that will be merged into the Psi tensor
+    if transpose_normalization:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
+    else:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
+
     out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
 
     for t in range(nlat_out):
@@ -187,7 +203,7 @@ def _precompute_convolution_tensor_dense(
             # compute cartesian coordinates of the rotated position
             x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
-            y = torch.sin(beta) * torch.sin(gamma)
+            y = torch.sin(beta) * torch.sin(gamma) * torch.ones_like(alpha)
 
             # normalize instead of clipping to ensure correct range
             norm = torch.sqrt(x * x + y * y + z * z)
@@ -197,13 +213,17 @@ def _precompute_convolution_tensor_dense(
             # compute spherical coordinates
             theta = torch.arccos(z)
-            phi = torch.arctan2(y, x) + torch.pi
+            phi = torch.arctan2(y, x)
+            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
             # find the indices where the rotated position falls into the support of the kernel
-            out[:, t, p, :, :] = kernel_handle(theta, phi)
+            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
+            out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals
 
     # take care of normalization
-    out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)
+    out = _normalize_convolution_tensor_dense(
+        out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
+    )
 
     return out
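The phi change matters because torch.arctan2 returns angles in (-pi, pi]: the old `+ torch.pi` rotated every azimuth by half a turn, while wrapping only the negative branch maps into [0, 2*pi) without moving the angle. A small check with hand-picked values:

    import torch

    y = torch.tensor([1.0, -1.0])
    x = torch.tensor([-1.0, -1.0])
    phi = torch.arctan2(y, x)                                   # tensor([ 2.3562, -2.3562])
    phi_wrapped = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
    # phi_wrapped == tensor([2.3562, 3.9270]); adding pi instead would give tensor([5.4978, 0.7854])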
@@ -217,30 +237,32 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         else:
             self.device = torch.device("cpu")
-        torch.manual_seed(333)
+        self.device = torch.device("cpu")
 
     @parameterized.expand(
         [
             # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 24), (8, 8), [3], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
             # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 8), (16, 24), [3], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
         ]
     )
     def test_disco_convolution(
@@ -251,6 +273,8 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         in_shape,
         out_shape,
         kernel_shape,
+        basis_type,
+        basis_norm_mode,
         grid_in,
         grid_out,
         transpose,
@@ -259,19 +283,38 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         nlat_in, nlon_in = in_shape
         nlat_out, nlon_out = out_shape
-        theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)
+        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
 
         Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
-        conv = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, groups=1, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=theta_cutoff).to(
-            self.device
-        )
-
-        _, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
-        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in
+        conv = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+            theta_cutoff=theta_cutoff,
+        ).to(self.device)
+
+        filter_basis = conv.filter_basis
 
         if transpose:
             psi_dense = _precompute_convolution_tensor_dense(
-                out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
+                out_shape,
+                in_shape,
+                kernel_shape,
+                filter_basis,
+                grid_in=grid_out,
+                grid_out=grid_in,
+                theta_cutoff=theta_cutoff,
+                transpose_normalization=transpose,
+                basis_norm_mode=basis_norm_mode,
+                merge_quadrature=True,
             ).to(self.device)
 
             psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()
@@ -279,7 +322,16 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
         else:
             psi_dense = _precompute_convolution_tensor_dense(
-                in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
+                in_shape,
+                out_shape,
+                kernel_shape,
+                filter_basis,
+                grid_in=grid_in,
+                grid_out=grid_out,
+                theta_cutoff=theta_cutoff,
+                transpose_normalization=transpose,
+                basis_norm_mode=basis_norm_mode,
+                merge_quadrature=True,
            ).to(self.device)
 
             psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
...
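The assertions compare the layer's sparse representation (conv.psi_idx, conv.psi_vals) against the dense precomputation after densifying via torch.sparse_coo_tensor. A toy version of that step with hypothetical indices and sizes:

    import torch

    idx = torch.tensor([[0, 1], [0, 1], [2, 5]])   # (ikernel, ilat_out, flattened input index)
    vals = torch.tensor([0.5, 0.25])
    psi = torch.sparse_coo_tensor(idx, vals, size=(2, 2, 8)).to_dense()
    assert psi[0, 0, 2] == 0.5 and psi[1, 1, 5] == 0.25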
@@ -183,21 +183,23 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
     @parameterized.expand(
         [
-            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            [129, 256, 129, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
-            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
+            [129, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 64, 128, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
+            [129, 256, 129, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
+            [64, 128, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5],
         ]
     )
-    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
+    def test_distributed_disco_conv(
+        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
+    ):
 
         B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
 
@@ -206,6 +208,8 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
             out_channels=C,
             in_shape=(nlat_in, nlon_in),
             out_shape=(nlat_out, nlon_out),
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
             kernel_shape=kernel_shape,
             groups=groups,
             grid_in=grid_in,
...
@@ -57,70 +57,71 @@ except ImportError as err:
 def _normalize_convolution_tensor_s2(
-    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="sum", merge_quadrature=False, eps=1e-9
+    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9
 ):
     """
-    Discretely normalizes the convolution tensor. Supports different normalization modes
+    Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
+    - "none": no normalization is applied.
+    - "individual": for each output latitude and filter basis function, the filter is numerically integrated over the sphere and normalized so that the integral yields 1.
+    - "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
     """
 
-    nlat_in, nlon_in = in_shape
-    nlat_out, nlon_out = out_shape
-
-    # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
-    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)
-
-    if transpose_normalization:
-        # pre-compute the quadrature weights
-        q = quad_weights[idx[1]].reshape(-1)
-
-        # loop through dimensions which require normalization
-        for ik in range(kernel_size):
-            for ilat in range(nlat_in):
-                # get relevant entries depending on the normalization mode
-                if basis_norm_mode == "individual":
-                    iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
-                    # normalize, while summing also over the input longitude dimension here as this is not available for the output
-                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
-                elif basis_norm_mode == "sum":
-                    # this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
-                    iidx = torch.argwhere(idx[2] == ilat)
-                    # normalize, while summing also over the input longitude dimension here as this is not available for the output
-                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
-                else:
-                    raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
-                if merge_quadrature:
-                    # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
-                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
-                else:
-                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
-    else:
-        # pre-compute the quadrature weights
-        q = quad_weights[idx[2]].reshape(-1)
-
-        # loop through dimensions which require normalization
-        for ik in range(kernel_size):
-            for ilat in range(nlat_out):
-                # get relevant entries depending on the normalization mode
-                if basis_norm_mode == "individual":
-                    iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
-                    # normalize
-                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
-                elif basis_norm_mode == "sum":
-                    # this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
-                    iidx = torch.argwhere(idx[1] == ilat)
-                    # normalize
-                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
-                else:
-                    raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
-                if merge_quadrature:
-                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
-                else:
-                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
+    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
+    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
+
+    # get indices for addressing kernels, input and output latitudes
+    ikernel = idx[0]
+    if transpose_normalization:
+        ilat_out = idx[2]
+        ilat_in = idx[1]
+        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
+        nlat_out = in_shape[0]
+        correction_factor = out_shape[1] / in_shape[1]
+    else:
+        ilat_out = idx[1]
+        ilat_in = idx[2]
+        nlat_out = out_shape[0]
+
+    # get the quadrature weights
+    q = quad_weights[ilat_in].reshape(-1)
+
+    # buffer to store intermediate values
+    vnorm = torch.zeros(kernel_size, nlat_out)
+
+    # loop through dimensions to compute the norms
+    for ik in range(kernel_size):
+        for ilat in range(nlat_out):
+            # find indices corresponding to the given output latitude and kernel basis function
+            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
+            # compute the 2-norm, accounting for the fact that it is 4-pi normalized
+            vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]) / 4 / torch.pi)
+
+    # loop over values and renormalize
+    for ik in range(kernel_size):
+        for ilat in range(nlat_out):
+            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
+            if basis_norm_mode == "individual":
+                val = vnorm[ik, ilat]
+            elif basis_norm_mode == "mean":
+                val = vnorm[ik, :].mean()
+            elif basis_norm_mode == "none":
+                val = 1.0
+            else:
+                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
+            if merge_quadrature:
+                psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (val + eps)
+            else:
+                psi_vals[iidx] = psi_vals[iidx] / (val + eps)
+
+    if transpose_normalization and merge_quadrature:
+        psi_vals = psi_vals / correction_factor
 
     return psi_vals
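Put differently, "individual" gives every (basis function, output latitude) group exactly unit norm, while "mean" fixes only the average scale and preserves the relative variation across latitudes. A toy illustration of the divisors involved, with made-up norms:

    import torch

    vnorm = torch.tensor([1.0, 2.0, 4.0])   # norms of one basis function at three output latitudes
    individual_divisors = vnorm              # "individual": each latitude group divided by its own norm
    mean_divisor = vnorm.mean()              # "mean": every latitude group divided by 7/3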
@@ -133,7 +134,7 @@ def _precompute_convolution_tensor_s2(
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
     transpose_normalization=False,
-    basis_norm_mode="sum",
+    basis_norm_mode="none",
     merge_quadrature=False,
 ):
     """
@@ -187,11 +188,11 @@ def _precompute_convolution_tensor_s2(
             # compute cartesian coordinates of the rotated position
             # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
             # and therefore applied with a negative sign
+            z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
             x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
             y = torch.sin(beta) * torch.sin(gamma)
-            z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
 
-            # normalization is emportant to avoid NaNs when arccos and atan are applied
+            # normalization is important to avoid NaNs when arccos and atan are applied
             # this can otherwise lead to spurious artifacts in the solution
             norm = torch.sqrt(x * x + y * y + z * z)
             x = x / norm
@@ -200,7 +201,8 @@ def _precompute_convolution_tensor_s2(
             # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
             theta = torch.arccos(z)
-            phi = torch.arctan2(y, x) + torch.pi
+            phi = torch.arctan2(y, x)
+            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
             # find the indices where the rotated position falls into the support of the kernel
             iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
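Taken together, these lines express every input point in a frame rotated so that the current output point sits at the pole, then read off the spherical offsets. A standalone sketch of the coordinate math, using scalar angles for brevity:

    import torch

    alpha, beta, gamma = torch.tensor(0.3), torch.tensor(1.1), torch.tensor(0.7)

    # YZY Euler rotation, with alpha applied passively (negative sign folded into z)
    z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
    x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
    y = torch.sin(beta) * torch.sin(gamma)

    # guard against round-off before arccos
    norm = torch.sqrt(x * x + y * y + z * z)
    theta = torch.arccos(z / norm)
    phi = torch.arctan2(y, x)
    phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)   # azimuth in [0, 2*pi)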
@@ -293,7 +295,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "sum",
+        basis_norm_mode: Optional[str] = "none",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
@@ -305,6 +307,9 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         self.nlat_in, self.nlon_in = in_shape
         self.nlat_out, self.nlon_out = out_shape
 
+        # make sure the p-shift works by checking that longitudes are divisible
+        assert self.nlon_in % self.nlon_out == 0
+
         # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
         if theta_cutoff is None:
             theta_cutoff = torch.pi / float(self.nlat_out - 1)
@@ -396,7 +401,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "sum",
+        basis_norm_mode: Optional[str] = "none",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
@@ -408,6 +413,9 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         self.nlat_in, self.nlon_in = in_shape
         self.nlat_out, self.nlon_out = out_shape
 
+        # make sure the p-shift works by checking that longitudes are divisible
+        assert self.nlon_out % self.nlon_in == 0
+
         # bandlimit
         if theta_cutoff is None:
             theta_cutoff = torch.pi / float(self.nlat_in - 1)
@@ -415,7 +423,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
 
-        # switch in_shape and out_shape since we want transpose conv
+        # switch in_shape and out_shape since we want the transpose convolution
         idx, vals = _precompute_convolution_tensor_s2(
             out_shape,
             in_shape,
...
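A usage sketch for the updated constructor signature (shapes and hyperparameters are illustrative, and the top-level import assumes the classes are re-exported as in the tests; note the new assert requires nlon_in to be divisible by nlon_out):

    import torch
    from torch_harmonics import DiscreteContinuousConvS2

    conv = DiscreteContinuousConvS2(
        in_channels=8,
        out_channels=8,
        in_shape=(16, 32),
        out_shape=(8, 16),
        kernel_shape=[2, 2],
        basis_type="morlet",
        basis_norm_mode="mean",
    )
    y = conv(torch.randn(4, 8, 16, 32))   # -> (4, 8, 8, 16)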
@@ -76,7 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
     transpose_normalization=False,
-    basis_norm_mode="sum",
+    basis_norm_mode="none",
     merge_quadrature=False,
 ):
     """
@@ -142,7 +142,8 @@ def _precompute_distributed_convolution_tensor_s2(
             # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
             theta = torch.arccos(z)
-            phi = torch.arctan2(y, x) + torch.pi
+            phi = torch.arctan2(y, x)
+            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
             # find the indices where the rotated position falls into the support of the kernel
             iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
@@ -207,7 +208,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "sum",
+        basis_norm_mode: Optional[str] = "none",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
@@ -347,7 +348,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "sum",
+        basis_norm_mode: Optional[str] = "none",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
...
@@ -41,12 +41,19 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
     if basis_type == "piecewise linear":
         return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
-    elif basis_type == "disk morlet":
-        return DiskMorletFilterBasis(kernel_shape=kernel_shape)
+    elif basis_type == "morlet":
+        return MorletFilterBasis(kernel_shape=kernel_shape)
+    elif basis_type == "zernike":
+        raise NotImplementedError()
     else:
         raise ValueError(f"Unknown basis_type {basis_type}")
 
+
+def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
+    """Helper function to compute the distance on a circle"""
+    return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
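The helper measures geodesic distance on the periodic interval [0, 2*pi); a quick check at the seam:

    import math
    import torch

    x1 = torch.tensor([0.1])
    x2 = torch.tensor([6.2])
    d = torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
    # d is ~0.183, not 6.1: the two angles sit on opposite sides of the 0/2*pi seam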
 class FilterBasis(metaclass=abc.ABCMeta):
     """
     Abstract base class for a filter basis
@@ -64,6 +71,13 @@ class FilterBasis(metaclass=abc.ABCMeta):
     def kernel_size(self):
         raise NotImplementedError
 
+    # @abc.abstractmethod
+    # def compute_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
+    #     """
+    #     Computes the values of the filter basis
+    #     """
+    #     raise NotImplementedError
+
     @abc.abstractmethod
     def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
         """
@@ -136,49 +150,49 @@ class PiecewiseLinearFilterBasis(FilterBasis):
         # disambiguate even and uneven cases and compute the support
         if nr % 2 == 1:
             ir = ((ikernel - 1) // nphi + 1) * dr
-            iphi = ((ikernel - 1) % nphi) * dphi
+            iphi = ((ikernel - 1) % nphi) * dphi - math.pi
         else:
             ir = (ikernel // nphi + 0.5) * dr
-            iphi = (ikernel % nphi) * dphi
+            iphi = (ikernel % nphi) * dphi - math.pi
 
         # find the indices where the rotated position falls into the support of the kernel
         if nr % 2 == 1:
             # find the support
             cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
-            cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
+            cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi)
             # find indices where conditions are met
             iidx = torch.argwhere(cond_r & cond_phi)
             # compute the distance to the collocation points
             dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
-            dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
+            dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
             # compute the value of the basis functions
             vals = 1 - dist_r / dr
             vals *= torch.where(
                 (iidx[:, 0] > 0),
-                (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
+                (1 - dist_phi / dphi),
                 1.0,
             )
         else:
-            # in the even case, the inner casis functions overlap into areas with a negative areas
+            # in the even case, the inner basis functions overlap into areas with negative radius
             rn = -r
-            phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
+            phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi)
             # find the support
             cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
-            cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
+            cond_phi = _circle_dist(phi, iphi).abs() <= dphi
             cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
-            cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
+            cond_phin = _circle_dist(phin, iphi) <= dphi
             # find indices where conditions are met
             iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
             dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
-            dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
+            dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
             dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
-            dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
+            dist_phin = _circle_dist(phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
             # compute the value of the basis functions
             vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr)
-            vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
+            vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi)
             valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr)
-            valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
+            valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi)
             vals += valsn
 
         return iidx, vals
@@ -190,9 +204,10 @@ class PiecewiseLinearFilterBasis(FilterBasis):
         else:
             return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)
-class DiskMorletFilterBasis(FilterBasis):
+class MorletFilterBasis(FilterBasis):
     """
-    Morlet-like Filter basis. A Gaussian is multiplied with a Fourier basis in x and y.
+    Morlet-style filter basis on the disk. A Gaussian window is multiplied with a Fourier basis in the x and y directions.
     """
 
     def __init__(
@@ -209,12 +224,12 @@ class DiskMorletFilterBasis(FilterBasis):
     @property
     def kernel_size(self):
-        return self.kernel_shape[0]*self.kernel_shape[1]
+        return self.kernel_shape[0] * self.kernel_shape[1]
 
-    def _gaussian_envelope(self, r: torch.Tensor, width: float =1.0):
-        return 1 / (2 * math.pi * width**2 ) * torch.exp(- 0.5 * r**2 / (width**2))
+    def _gaussian_window(self, r: torch.Tensor, width: float = 1.0):
+        return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
 
-    def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
+    def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
         """
         Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
         """
@@ -227,26 +242,18 @@ class DiskMorletFilterBasis(FilterBasis):
         # get relevant indices
         iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
 
-        # # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
-        # width = 0.01
-        width = 0.25
-        # width = 1.0
-        # envelope = self._gaussian_envelope(r, width=0.25 * r_cutoff)
-
         # get x and y
         x = r * torch.sin(phi) / r_cutoff
         y = r * torch.cos(phi) / r_cutoff
 
-        harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel/2) * math.pi * x / width), torch.cos(torch.ceil(nkernel/2) * math.pi * x / width))
-        harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel/2) * math.pi * y / width), torch.cos(torch.ceil(mkernel/2) * math.pi * y / width))
+        harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel / 2) * math.pi * x / width), torch.cos(torch.ceil(nkernel / 2) * math.pi * x / width))
+        harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel / 2) * math.pi * y / width), torch.cos(torch.ceil(mkernel / 2) * math.pi * y / width))
 
         # disk area
         # disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff))
-        disk_area = 1
+        disk_area = 1.0
 
         # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
-        vals = self._gaussian_envelope(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
-        # vals = torch.ones_like(vals)
+        vals = self._gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
 
         return iidx, vals
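For intuition, each Morlet basis function is a radial Gaussian window multiplied by sine/cosine harmonics in the disk's x and y coordinates. A standalone sketch mirroring the window and default width above (the sample grid is made up):

    import math
    import torch

    def gaussian_window(r, width=0.25):
        return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / width**2)

    r = torch.linspace(0.0, 1.0, 5)          # radius, already normalized by r_cutoff
    phi = torch.full_like(r, math.pi / 2)    # points along the x axis of the disk
    x = r * torch.sin(phi)
    y = r * torch.cos(phi)
    # one even/odd pair of basis functions in x under the shared Gaussian window
    vals_even = gaussian_window(r) * torch.cos(math.pi * x / 0.25)
    vals_odd = gaussian_window(r) * torch.sin(math.pi * x / 0.25)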
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
 #
@@ -31,6 +31,7 @@
 import numpy as np
 
+
 def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
 
     if (grid != "equidistant") and periodic:
@@ -50,27 +51,31 @@ def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
     return xlg, wlg
 
+
 def _precompute_latitudes(nlat, grid="equiangular"):
     r"""
     Convenience routine to precompute latitudes
     """
 
-    # compute coordinates
+    # compute coordinates in the cosine theta domain
     xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
 
+    # to perform the quadrature and account for the jacobian of the sphere, the quadrature rule
+    # is formulated in the cosine theta domain, which is designed to integrate functions of cos theta
    lats = np.flip(np.arccos(xlg)).copy()
    wlg = np.flip(wlg).copy()

    return lats, wlg
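Because the rule is built in the cos(theta) domain, the returned weights already absorb the sin(theta) Jacobian, so integrating the constant 1 over colatitude should give the integral of sin(theta) over [0, pi], which is 2. A quick check using the private helper, as the tests do:

    import numpy as np
    from torch_harmonics.quadrature import _precompute_latitudes

    lats, w = _precompute_latitudes(64, grid="legendre-gauss")
    assert np.isclose(np.sum(w), 2.0)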
 
+
 def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
     r"""
     Helper routine which returns equidistant nodes with trapezoidal weights
     on the interval [a, b]
     """
 
-    xlg = np.linspace(a, b, n)
-    wlg = (b - a) / (n - 1) * np.ones(n)
+    xlg = np.linspace(a, b, n, endpoint=periodic)
+    wlg = (b - a) / (n - periodic * 1) * np.ones(n)
 
     if not periodic:
         wlg[0] *= 0.5
@@ -78,6 +83,7 @@ def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
     return xlg, wlg
 
+
 def legendre_gauss_weights(n, a=-1.0, b=1.0):
     r"""
     Helper routine which returns the Legendre-Gauss nodes and weights
@@ -90,6 +96,7 @@ def legendre_gauss_weights(n, a=-1.0, b=1.0):
     return xlg, wlg
 
+
 def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
     r"""
     Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
@@ -102,33 +109,33 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
     # Vandermonde Matrix
     vdm = np.zeros((n, n))
 
     # initialize Chebyshev nodes as first guess
     for i in range(n):
-        tlg[i] = -np.cos(np.pi*i / (n-1))
+        tlg[i] = -np.cos(np.pi * i / (n - 1))
 
     tmp = 2.0
 
     for i in range(maxiter):
         tmp = tlg
 
-        vdm[:,0] = 1.0
-        vdm[:,1] = tlg
+        vdm[:, 0] = 1.0
+        vdm[:, 1] = tlg
 
         for k in range(2, n):
-            vdm[:, k] = ( (2*k-1) * tlg * vdm[:, k-1] - (k-1) * vdm[:, k-2] ) / k
+            vdm[:, k] = ((2 * k - 1) * tlg * vdm[:, k - 1] - (k - 1) * vdm[:, k - 2]) / k
 
-        tlg = tmp - ( tlg*vdm[:, n-1] - vdm[:, n-2] ) / ( n * vdm[:, n-1])
+        tlg = tmp - (tlg * vdm[:, n - 1] - vdm[:, n - 2]) / (n * vdm[:, n - 1])
 
-        if (max(abs(tlg - tmp).flatten()) < tol ):
+        if max(abs(tlg - tmp).flatten()) < tol:
             break
 
-    wlg = 2.0 / ( (n*(n-1))*(vdm[:, n-1]**2))
+    wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n - 1] ** 2))
 
     # rescale
     tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5
     wlg = wlg * (b - a) * 0.5
 
     return tlg, wlg
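Lobatto nodes include both interval endpoints, and an n-point rule integrates polynomials up to degree 2n - 3 exactly, which is the bound the lobatto lmax comment in the SHT section below relies on. A quick check:

    import numpy as np

    x, w = lobatto_weights(5)
    assert np.isclose(x[0], -1.0) and np.isclose(x[-1], 1.0)   # endpoints are nodes
    assert np.isclose(np.sum(w * 7 * x**6), 2.0)               # degree 6 <= 2*5 - 3, integrated exactly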
...@@ -140,12 +147,12 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0): ...@@ -140,12 +147,12 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
[1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018. [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
""" """
assert(n > 1) assert n > 1
tcc = np.cos(np.linspace(np.pi, 0, n)) tcc = np.cos(np.linspace(np.pi, 0, n))
if n == 2: if n == 2:
wcc = np.array([1., 1.]) wcc = np.array([1.0, 1.0])
else: else:
n1 = n - 1 n1 = n - 1
...@@ -153,13 +160,13 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0): ...@@ -153,13 +160,13 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
l = len(N) l = len(N)
m = n1 - l m = n1 - l
v = np.concatenate([2 / N / (N-2), 1 / N[-1:], np.zeros(m)]) v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
v = 0 - v[:-1] - v[-1:0:-1] v = 0 - v[:-1] - v[-1:0:-1]
g0 = -np.ones(n1) g0 = -np.ones(n1)
g0[l] = g0[l] + n1 g0[l] = g0[l] + n1
g0[m] = g0[m] + n1 g0[m] = g0[m] + n1
g = g0 / (n1**2 - 1 + (n1%2)) g = g0 / (n1**2 - 1 + (n1 % 2))
wcc = np.fft.ifft(v + g).real wcc = np.fft.ifft(v + g).real
wcc = np.concatenate((wcc, wcc[:1])) wcc = np.concatenate((wcc, wcc[:1]))
...@@ -169,6 +176,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0): ...@@ -169,6 +176,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
return tcc, wcc return tcc, wcc
def fejer2_weights(n, a=-1.0, b=1.0):
    r"""
    Computation of the Fejer quadrature nodes and weights.
@@ -177,7 +185,7 @@ def fejer2_weights(n, a=-1.0, b=1.0):
    [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
    """

    assert n > 2

    tcc = np.cos(np.linspace(np.pi, 0, n))
@@ -186,7 +194,7 @@ def fejer2_weights(n, a=-1.0, b=1.0):
    l = len(N)
    m = n1 - l

    v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
    v = 0 - v[:-1] - v[-1:0:-1]

    wcc = np.fft.ifft(v).real
@@ -196,4 +204,4 @@ def fejer2_weights(n, a=-1.0, b=1.0):
    tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
    wcc = wcc * (b - a) * 0.5

    return tcc, wcc
\ No newline at end of file
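
# Comparison sketch (assuming, as for the other rules, that the returned node
# and weight arrays have matching length): Fejer-2 uses the same Chebyshev
# nodes as Clenshaw-Curtis but a slightly different weight construction; for
# smooth integrands both rules give essentially the same accuracy.
tf2, wf2 = fejer2_weights(17)
tcc, wcc = clenshaw_curtiss_weights(17)
ref = 1.4936483  # int_{-1}^{1} exp(-x^2) dx = sqrt(pi) * erf(1)
assert np.allclose(np.sum(wf2 * np.exp(-tf2**2)), ref, atol=1e-6)
assert np.allclose(np.sum(wcc * np.exp(-tcc**2)), ref, atol=1e-6)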
@@ -68,19 +68,25 @@ class RealSHT(nn.Module):
        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            # the maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
            # and therefore lmax = nlat - 1 (inclusive)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            # the maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
            # and therefore lmax = nlat - 2 (inclusive)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degree nlat;
            # however, we observe that the quadrature is remarkably accurate for higher degrees,
            # which is why we do not choose a lower lmax for now
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))
@@ -92,24 +98,24 @@ class RealSHT(nn.Module):
        weights = torch.from_numpy(w)
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        pct = torch.from_numpy(pct)
        weights = torch.einsum("mlk,k->mlk", pct, weights)

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat
        assert x.shape[-1] == self.nlon

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
@@ -124,12 +130,12 @@ class RealSHT(nn.Module):
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction
        xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype))
        xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype))
        x = torch.view_as_complex(xout)

        return x
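
# Round-trip sketch (not part of the commit; a minimal example assuming the
# module-level imports of this file): on the "legendre-gauss" grid the
# quadrature is exact for the chosen band limit, so analysis followed by
# synthesis (InverseRealSHT, defined just below) acts as an exact spectral
# truncation.
nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="legendre-gauss")
isht = InverseRealSHT(nlat, nlon, grid="legendre-gauss")
signal = torch.randn(nlat, nlon, dtype=torch.float64)
truncated = isht(sht(signal))  # coefficients have shape (lmax, mmax)
# applying the round trip a second time leaves the truncated signal unchanged
assert torch.allclose(isht(sht(truncated)), truncated, atol=1e-10)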
class InverseRealSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) SHT.
@@ -157,12 +164,12 @@ class InverseRealSHT(nn.Module):
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))
@@ -174,27 +181,27 @@ class InverseRealSHT(nn.Module):
        pct = torch.from_numpy(pct)

        # register buffer
        self.register_buffer("pct", pct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):

        if len(x.shape) < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {len(x.shape)} instead")

        assert x.shape[-2] == self.lmax
        assert x.shape[-1] == self.mmax

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
        im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
        xs = torch.stack((rl, im), -1)

        # apply the inverse (real) FFT
@@ -238,13 +245,12 @@ class RealVectorSHT(nn.Module):
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))
@@ -258,28 +264,28 @@ class RealVectorSHT(nn.Module):
        # combine integration weights and the normalization factor into one:
        l = torch.arange(0, self.lmax)
        norm_factor = 1.0 / l / (l + 1)
        # the l = 0 mode carries no 1 / (l * (l + 1)) scaling
        norm_factor[0] = 1.0
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat
        assert x.shape[-1] == self.nlon

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
@@ -295,20 +301,24 @@ class RealVectorSHT(nn.Module):
        # contraction - spheroidal component
        # real component
        xout[..., 0, :, :, 0] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[0].to(x.dtype)) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[1].to(x.dtype)
        )
        # imag component
        xout[..., 0, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[0].to(x.dtype)) + torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[1].to(x.dtype)
        )

        # contraction - toroidal component
        # real component
        xout[..., 1, :, :, 0] = -torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[1].to(x.dtype)) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[0].to(x.dtype)
        )
        # imag component
        xout[..., 1, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[1].to(x.dtype)) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[0].to(x.dtype)
        )

        return torch.view_as_complex(xout)
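
# Usage sketch for the vector transform (reusing nlat, nlon from the scalar
# sketch above; the component ordering is an assumption for illustration):
# the two tangential vector components are stacked along dim -3, and the
# spheroidal/toroidal coefficients come back along the same dim.
vsht = RealVectorSHT(nlat, nlon, grid="legendre-gauss")
vec = torch.randn(2, nlat, nlon, dtype=torch.float64)  # (u_theta, u_phi)
vcoeffs = vsht(vec)  # complex, shape (2, lmax, mmax): spheroidal, toroidal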
@@ -321,6 +331,7 @@ class InverseRealVectorSHT(nn.Module):
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        super().__init__()
@@ -337,12 +348,12 @@ class InverseRealVectorSHT(nn.Module):
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))
@@ -354,40 +365,36 @@ class InverseRealVectorSHT(nn.Module):
        dpct = torch.from_numpy(dpct)

        # register weights
        self.register_buffer("dpct", dpct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.lmax
        assert x.shape[-1] == self.mmax

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # contraction - spheroidal component
        # real component
        srl = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
        # imag component
        sim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) + torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))

        # contraction - toroidal component
        # real component
        trl = -torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
        # imag component
        tim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) - torch.einsum("...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))

        # reassemble
        s = torch.stack((srl, sim), -1)
...
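
# End-to-end sketch (reusing vsht and vec from the sketch above; assumes the
# vector round trip is an exact truncation on the "legendre-gauss" grid,
# analogous to the scalar case, up to quadrature accuracy):
ivsht = InverseRealVectorSHT(nlat, nlon, grid="legendre-gauss")
vec_trunc = ivsht(vsht(vec))
assert torch.allclose(ivsht(vsht(vec_trunc)), vec_trunc, atol=1e-8)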