Unverified Commit 780fd143 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Improved computation of Morlet filter basis (#65)

* Improved computation of Morlet filter basis and switched to a Hann window.

* addresses #064 and some cleanup
parent 9eea871c
......@@ -2,6 +2,11 @@
## Versioning
### v0.7.5
* New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function
### v0.7.4
* New filter basis normalization in DISCO convolutions
......
......@@ -422,7 +422,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=True,
pos_embed=False,
......@@ -437,14 +436,13 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=False,
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
kernel_shape=[4, 4],
encoder_kernel_shape=[4, 4],
kernel_shape=[2, 2],
encoder_kernel_shape=[2, 2],
filter_basis_type="morlet",
upsample_sht = True,
)
......@@ -456,9 +454,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=False,
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
......@@ -501,7 +498,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
start_time = time.time()
print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=100, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=1, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
if nfuture > 0:
print(f'Training {model_name}, {nfuture} step')
......
......@@ -34,79 +34,66 @@ import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
def plot_sphere(data,
fig=None,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_latitude=0,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
def plot_sphere(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs):
if fig == None:
fig = plt.figure()
nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2*np.pi, nlon)
lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat)
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi
Lat = Lat*180/np.pi
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi
# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar:
plt.colorbar(im)
plt.colorbar(im, extend="both")
plt.title(title, y=1.05)
return im
def plot_data(data,
fig=None,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
def plot_data(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_longitude=0, lon=None, lat=None, **kwargs):
if fig == None:
fig = plt.figure()
nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2*np.pi, nlon)
lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat)
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi
Lat = Lat*180/np.pi
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi
# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar:
plt.colorbar(im)
plt.colorbar(im, extend="both")
plt.title(title, y=1.05)
return im
\ No newline at end of file
return im
......@@ -156,7 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n",
"model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
]
},
......
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.4"
__version__ = "0.7.5a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
......@@ -88,6 +88,7 @@ def _normalize_convolution_tensor_s2(
# buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out)
support = torch.zeros(kernel_size, nlat_out)
# loop through dimensions to compute the norms
for ik in range(kernel_size):
......@@ -100,6 +101,10 @@ def _normalize_convolution_tensor_s2(
# vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
# compute the support
support[ik, ilat] = torch.sum(q[iidx])
# loop over values and renormalize
for ik in range(kernel_size):
for ilat in range(nlat_out):
......@@ -110,6 +115,8 @@ def _normalize_convolution_tensor_s2(
val = vnorm[ik, ilat]
elif basis_norm_mode == "mean":
val = vnorm[ik, :].mean()
elif basis_norm_mode == "support":
val = support[ik, ilat]
elif basis_norm_mode == "none":
val = 1.0
else:
......
......@@ -148,14 +148,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
input_dim,
output_dim,
conv_type="local",
operator_type="driscoll-healy",
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.ReLU,
act_layer=nn.GELU,
norm_layer=nn.Identity,
inner_skip="None",
outer_skip="linear",
inner_skip="none",
outer_skip="identity",
use_mlp=True,
disco_kernel_shape=[3, 4],
disco_basis_type="piecewise linear",
......@@ -185,7 +184,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
)
elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
else:
raise ValueError(f"Unknown convolution type {conv_type}")
......@@ -274,8 +273,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
-----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
kernel_shape: tuple, int
scale_factor : int, optional
Scale factor to use, by default 3
......@@ -314,7 +311,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Example
-----------
>>> model = SphericalFourierNeuralOperatorNet(
>>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
... in_chans=2,
......@@ -340,15 +337,14 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
def __init__(
self,
img_size=(128, 256),
operator_type="driscoll-healy",
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=4,
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="relu",
activation_function="gelu",
kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4],
filter_basis_type="piecewise linear",
......@@ -365,7 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
):
super().__init__()
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
......@@ -438,10 +433,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
bias=False,
)
# prepare the SHT
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
# compute the modes for the sht
modes_lat = self.h
# due to some spectral artifacts with cufft, we substract one mode here
modes_lon = (self.w // 2 + 1) - 1
modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
......@@ -451,9 +448,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
first_layer = i == 0
last_layer = i == self.num_layers - 1
inner_skip = "none"
outer_skip = "identity"
if first_layer:
norm_layer = norm_layer1
elif last_layer:
......@@ -467,14 +461,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.embed_dim,
self.embed_dim,
conv_type="global" if i % 2 == 0 else "local",
operator_type=self.operator_type,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
disco_kernel_shape=kernel_shape,
disco_basis_type=filter_basis_type,
......
......@@ -31,8 +31,7 @@
import torch
import torch.nn as nn
from torch_harmonics import *
from torch_harmonics import RealSHT, InverseRealSHT
from ._layers import *
......@@ -50,17 +49,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
inverse_transform,
input_dim,
output_dim,
operator_type="driscoll-healy",
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.ReLU,
act_layer=nn.GELU,
norm_layer=nn.Identity,
factorization=None,
separable=False,
rank=128,
inner_skip="linear",
outer_skip=None,
inner_skip="none",
outer_skip="identity",
use_mlp=True,
):
super().__init__()
......@@ -73,7 +68,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
......@@ -148,8 +143,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
......@@ -204,7 +197,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def __init__(
self,
img_size=(128, 256),
operator_type="driscoll-healy",
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=3,
......@@ -212,7 +204,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="relu",
activation_function="gelu",
encoder_layers=1,
use_mlp=True,
mlp_ratio=2.0,
......@@ -227,7 +219,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
super().__init__()
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
......@@ -312,7 +303,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
# compute the modes for the sht
modes_lat = self.h
# due to some spectral artifacts with cufft, we substract one mode here
modes_lon = (self.w // 2 + 1) -1
modes_lon = (self.w // 2 + 1) - 1
modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
......@@ -327,12 +318,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
first_layer = i == 0
last_layer = i == self.num_layers - 1
forward_transform = self.trans_down if first_layer else self.trans
inverse_transform = self.itrans_up if last_layer else self.itrans
inner_skip = "none"
outer_skip = "identity"
if first_layer:
norm_layer = norm_layer1
elif last_layer:
......@@ -341,18 +326,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(
forward_transform,
inverse_transform,
self.trans_down if first_layer else self.trans,
self.itrans_up if last_layer else self.itrans,
self.embed_dim,
self.embed_dim,
operator_type=self.operator_type,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
)
......
......@@ -239,7 +239,10 @@ class MorletFilterBasis(FilterBasis):
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
def hann_window(self, r: torch.Tensor, width: float = 1.0):
return torch.cos(0.5 * torch.pi * r / width) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
......@@ -252,19 +255,20 @@ class MorletFilterBasis(FilterBasis):
# get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
# get x and y
x = r * torch.sin(phi) / r_cutoff
y = r * torch.cos(phi) / r_cutoff
harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel / 2) * math.pi * x / width), torch.cos(torch.ceil(nkernel / 2) * math.pi * x / width))
harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel / 2) * math.pi * y / width), torch.cos(torch.ceil(mkernel / 2) * math.pi * y / width))
# get corresponding r, phi, x and y coordinates
r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
phi = phi[iidx[:, 1], iidx[:, 2]]
x = r * torch.sin(phi)
y = r * torch.cos(phi)
n = nkernel[iidx[:, 0], 0, 0]
m = mkernel[iidx[:, 0], 0, 0]
# disk area
# disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff))
disk_area = 1.0
harmonic = torch.where(n % 2 == 1, torch.sin(torch.ceil(n / 2) * math.pi * x / width), torch.cos(torch.ceil(n / 2) * math.pi * x / width))
harmonic *= torch.where(m % 2 == 1, torch.sin(torch.ceil(m / 2) * math.pi * y / width), torch.cos(torch.ceil(m / 2) * math.pi * y / width))
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals = self.gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
# computes the envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
# vals = self.gaussian_window(r, width=width) * harmonic
vals = self.hann_window(r, width=width) * harmonic
return iidx, vals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment