Unverified Commit 780fd143 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Improved computation of Morlet filter basis (#65)

* Improved computation of Morlet filter basis and switched to a Hann window.

* addresses #064 and some cleanup
parent 9eea871c
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
## Versioning ## Versioning
### v0.7.5
* New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function
### v0.7.4 ### v0.7.4
* New filter basis normalization in DISCO convolutions * New filter basis normalization in DISCO convolutions
......
...@@ -422,7 +422,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -422,7 +422,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4, num_layers=4,
scale_factor=2, scale_factor=2,
embed_dim=32, embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu", activation_function="gelu",
big_skip=True, big_skip=True,
pos_embed=False, pos_embed=False,
...@@ -437,14 +436,13 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -437,14 +436,13 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4, num_layers=4,
scale_factor=2, scale_factor=2,
embed_dim=32, embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu", activation_function="gelu",
big_skip=False, big_skip=True,
pos_embed=False, pos_embed=False,
use_mlp=True, use_mlp=True,
normalization_layer="none", normalization_layer="none",
kernel_shape=[4, 4], kernel_shape=[2, 2],
encoder_kernel_shape=[4, 4], encoder_kernel_shape=[2, 2],
filter_basis_type="morlet", filter_basis_type="morlet",
upsample_sht = True, upsample_sht = True,
) )
...@@ -456,9 +454,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -456,9 +454,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
num_layers=4, num_layers=4,
scale_factor=2, scale_factor=2,
embed_dim=32, embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu", activation_function="gelu",
big_skip=False, big_skip=True,
pos_embed=False, pos_embed=False,
use_mlp=True, use_mlp=True,
normalization_layer="none", normalization_layer="none",
...@@ -501,7 +498,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -501,7 +498,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
start_time = time.time() start_time = time.time()
print(f"Training {model_name}, single step") print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=100, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads) train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=1, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
if nfuture > 0: if nfuture > 0:
print(f'Training {model_name}, {nfuture} step') print(f'Training {model_name}, {nfuture} step')
......
...@@ -34,79 +34,66 @@ import matplotlib.pyplot as plt ...@@ -34,79 +34,66 @@ import matplotlib.pyplot as plt
import cartopy import cartopy
import cartopy.crs as ccrs import cartopy.crs as ccrs
def plot_sphere(data,
fig=None, def plot_sphere(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs):
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_latitude=0,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
if fig == None: if fig == None:
fig = plt.figure() fig = plt.figure()
nlat = data.shape[-2] nlat = data.shape[-2]
nlon = data.shape[-1] nlon = data.shape[-1]
if lon is None: if lon is None:
lon = np.linspace(0, 2*np.pi, nlon) lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None: if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat) lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat) Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude) proj = ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude) # proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj) ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi Lon = Lon * 180 / np.pi
Lat = Lat*180/np.pi Lat = Lat * 180 / np.pi
# contour data over the map. # contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs) im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines: if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5) ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar: if colorbar:
plt.colorbar(im) plt.colorbar(im, extend="both")
plt.title(title, y=1.05) plt.title(title, y=1.05)
return im return im
def plot_data(data,
fig=None, def plot_data(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_longitude=0, lon=None, lat=None, **kwargs):
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_longitude=0,
lon=None,
lat=None,
**kwargs):
if fig == None: if fig == None:
fig = plt.figure() fig = plt.figure()
nlat = data.shape[-2] nlat = data.shape[-2]
nlon = data.shape[-1] nlon = data.shape[-1]
if lon is None: if lon is None:
lon = np.linspace(0, 2*np.pi, nlon) lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None: if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat) lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat) Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.PlateCarree(central_longitude=central_longitude) proj = ccrs.PlateCarree(central_longitude=central_longitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude) # proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj) ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi Lon = Lon * 180 / np.pi
Lat = Lat*180/np.pi Lat = Lat * 180 / np.pi
# contour data over the map. # contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs) im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines: if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5) ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar: if colorbar:
plt.colorbar(im) plt.colorbar(im, extend="both")
plt.title(title, y=1.05) plt.title(title, y=1.05)
return im return im
\ No newline at end of file
...@@ -156,7 +156,7 @@ ...@@ -156,7 +156,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n", "model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n" " num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
] ]
}, },
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
__version__ = "0.7.4" __version__ = "0.7.5a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
...@@ -88,6 +88,7 @@ def _normalize_convolution_tensor_s2( ...@@ -88,6 +88,7 @@ def _normalize_convolution_tensor_s2(
# buffer to store intermediate values # buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out) vnorm = torch.zeros(kernel_size, nlat_out)
support = torch.zeros(kernel_size, nlat_out)
# loop through dimensions to compute the norms # loop through dimensions to compute the norms
for ik in range(kernel_size): for ik in range(kernel_size):
...@@ -100,6 +101,10 @@ def _normalize_convolution_tensor_s2( ...@@ -100,6 +101,10 @@ def _normalize_convolution_tensor_s2(
# vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx])) # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx]) vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
# compute the support
support[ik, ilat] = torch.sum(q[iidx])
# loop over values and renormalize # loop over values and renormalize
for ik in range(kernel_size): for ik in range(kernel_size):
for ilat in range(nlat_out): for ilat in range(nlat_out):
...@@ -110,6 +115,8 @@ def _normalize_convolution_tensor_s2( ...@@ -110,6 +115,8 @@ def _normalize_convolution_tensor_s2(
val = vnorm[ik, ilat] val = vnorm[ik, ilat]
elif basis_norm_mode == "mean": elif basis_norm_mode == "mean":
val = vnorm[ik, :].mean() val = vnorm[ik, :].mean()
elif basis_norm_mode == "support":
val = support[ik, ilat]
elif basis_norm_mode == "none": elif basis_norm_mode == "none":
val = 1.0 val = 1.0
else: else:
......
...@@ -148,14 +148,13 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -148,14 +148,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
input_dim, input_dim,
output_dim, output_dim,
conv_type="local", conv_type="local",
operator_type="driscoll-healy",
mlp_ratio=2.0, mlp_ratio=2.0,
drop_rate=0.0, drop_rate=0.0,
drop_path=0.0, drop_path=0.0,
act_layer=nn.ReLU, act_layer=nn.GELU,
norm_layer=nn.Identity, norm_layer=nn.Identity,
inner_skip="None", inner_skip="none",
outer_skip="linear", outer_skip="identity",
use_mlp=True, use_mlp=True,
disco_kernel_shape=[3, 4], disco_kernel_shape=[3, 4],
disco_basis_type="piecewise linear", disco_basis_type="piecewise linear",
...@@ -185,7 +184,7 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -185,7 +184,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1), theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
) )
elif conv_type == "global": elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False) self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
else: else:
raise ValueError(f"Unknown convolution type {conv_type}") raise ValueError(f"Unknown convolution type {conv_type}")
...@@ -274,8 +273,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -274,8 +273,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
----------- -----------
img_shape : tuple, optional img_shape : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
kernel_shape: tuple, int kernel_shape: tuple, int
scale_factor : int, optional scale_factor : int, optional
Scale factor to use, by default 3 Scale factor to use, by default 3
...@@ -314,7 +311,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -314,7 +311,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Example Example
----------- -----------
>>> model = SphericalFourierNeuralOperatorNet( >>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256), ... img_shape=(128, 256),
... scale_factor=4, ... scale_factor=4,
... in_chans=2, ... in_chans=2,
...@@ -340,15 +337,14 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -340,15 +337,14 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
def __init__( def __init__(
self, self,
img_size=(128, 256), img_size=(128, 256),
operator_type="driscoll-healy",
grid="equiangular", grid="equiangular",
grid_internal="legendre-gauss", grid_internal="legendre-gauss",
scale_factor=4, scale_factor=3,
in_chans=3, in_chans=3,
out_chans=3, out_chans=3,
embed_dim=256, embed_dim=256,
num_layers=4, num_layers=4,
activation_function="relu", activation_function="gelu",
kernel_shape=[3, 4], kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4], encoder_kernel_shape=[3, 4],
filter_basis_type="piecewise linear", filter_basis_type="piecewise linear",
...@@ -365,7 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -365,7 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
): ):
super().__init__() super().__init__()
self.operator_type = operator_type
self.img_size = img_size self.img_size = img_size
self.grid = grid self.grid = grid
self.grid_internal = grid_internal self.grid_internal = grid_internal
...@@ -438,10 +433,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -438,10 +433,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
bias=False, bias=False,
) )
# prepare the SHT # compute the modes for the sht
modes_lat = int(self.h * self.hard_thresholding_fraction) modes_lat = self.h
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction) # due to some spectral artifacts with cufft, we substract one mode here
modes_lat = modes_lon = min(modes_lat, modes_lon) modes_lon = (self.w // 2 + 1) - 1
modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float() self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float() self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
...@@ -451,9 +448,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -451,9 +448,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
first_layer = i == 0 first_layer = i == 0
last_layer = i == self.num_layers - 1 last_layer = i == self.num_layers - 1
inner_skip = "none"
outer_skip = "identity"
if first_layer: if first_layer:
norm_layer = norm_layer1 norm_layer = norm_layer1
elif last_layer: elif last_layer:
...@@ -467,14 +461,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -467,14 +461,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.embed_dim, self.embed_dim,
self.embed_dim, self.embed_dim,
conv_type="global" if i % 2 == 0 else "local", conv_type="global" if i % 2 == 0 else "local",
operator_type=self.operator_type,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
drop_rate=drop_rate, drop_rate=drop_rate,
drop_path=dpr[i], drop_path=dpr[i],
act_layer=self.activation_function, act_layer=self.activation_function,
norm_layer=norm_layer, norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp, use_mlp=use_mlp,
disco_kernel_shape=kernel_shape, disco_kernel_shape=kernel_shape,
disco_basis_type=filter_basis_type, disco_basis_type=filter_basis_type,
......
...@@ -31,8 +31,7 @@ ...@@ -31,8 +31,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import *
from ._layers import * from ._layers import *
...@@ -50,17 +49,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -50,17 +49,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
inverse_transform, inverse_transform,
input_dim, input_dim,
output_dim, output_dim,
operator_type="driscoll-healy",
mlp_ratio=2.0, mlp_ratio=2.0,
drop_rate=0.0, drop_rate=0.0,
drop_path=0.0, drop_path=0.0,
act_layer=nn.ReLU, act_layer=nn.GELU,
norm_layer=nn.Identity, norm_layer=nn.Identity,
factorization=None, inner_skip="none",
separable=False, outer_skip="identity",
rank=128,
inner_skip="linear",
outer_skip=None,
use_mlp=True, use_mlp=True,
): ):
super().__init__() super().__init__()
...@@ -73,7 +68,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -73,7 +68,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity": if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0 gain_factor /= 2.0
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False) self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
if inner_skip == "linear": if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1) self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
...@@ -148,8 +143,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -148,8 +143,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
---------- ----------
img_shape : tuple, optional img_shape : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
scale_factor : int, optional scale_factor : int, optional
Scale factor to use, by default 3 Scale factor to use, by default 3
in_chans : int, optional in_chans : int, optional
...@@ -204,7 +197,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -204,7 +197,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def __init__( def __init__(
self, self,
img_size=(128, 256), img_size=(128, 256),
operator_type="driscoll-healy",
grid="equiangular", grid="equiangular",
grid_internal="legendre-gauss", grid_internal="legendre-gauss",
scale_factor=3, scale_factor=3,
...@@ -212,7 +204,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -212,7 +204,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
out_chans=3, out_chans=3,
embed_dim=256, embed_dim=256,
num_layers=4, num_layers=4,
activation_function="relu", activation_function="gelu",
encoder_layers=1, encoder_layers=1,
use_mlp=True, use_mlp=True,
mlp_ratio=2.0, mlp_ratio=2.0,
...@@ -227,7 +219,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -227,7 +219,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
super().__init__() super().__init__()
self.operator_type = operator_type
self.img_size = img_size self.img_size = img_size
self.grid = grid self.grid = grid
self.grid_internal = grid_internal self.grid_internal = grid_internal
...@@ -312,7 +303,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -312,7 +303,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
# compute the modes for the sht # compute the modes for the sht
modes_lat = self.h modes_lat = self.h
# due to some spectral artifacts with cufft, we substract one mode here # due to some spectral artifacts with cufft, we substract one mode here
modes_lon = (self.w // 2 + 1) -1 modes_lon = (self.w // 2 + 1) - 1
modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction) modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
...@@ -327,12 +318,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -327,12 +318,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
first_layer = i == 0 first_layer = i == 0
last_layer = i == self.num_layers - 1 last_layer = i == self.num_layers - 1
forward_transform = self.trans_down if first_layer else self.trans
inverse_transform = self.itrans_up if last_layer else self.itrans
inner_skip = "none"
outer_skip = "identity"
if first_layer: if first_layer:
norm_layer = norm_layer1 norm_layer = norm_layer1
elif last_layer: elif last_layer:
...@@ -341,18 +326,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -341,18 +326,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
norm_layer = norm_layer1 norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock( block = SphericalFourierNeuralOperatorBlock(
forward_transform, self.trans_down if first_layer else self.trans,
inverse_transform, self.itrans_up if last_layer else self.itrans,
self.embed_dim, self.embed_dim,
self.embed_dim, self.embed_dim,
operator_type=self.operator_type,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
drop_rate=drop_rate, drop_rate=drop_rate,
drop_path=dpr[i], drop_path=dpr[i],
act_layer=self.activation_function, act_layer=self.activation_function,
norm_layer=norm_layer, norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp, use_mlp=use_mlp,
) )
......
...@@ -239,7 +239,10 @@ class MorletFilterBasis(FilterBasis): ...@@ -239,7 +239,10 @@ class MorletFilterBasis(FilterBasis):
def gaussian_window(self, r: torch.Tensor, width: float = 1.0): def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2)) return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25): def hann_window(self, r: torch.Tensor, width: float = 1.0):
return torch.cos(0.5 * torch.pi * r / width) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
""" """
Computes the index set that falls into the isotropic kernel's support and returns both indices and values. Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
""" """
...@@ -252,19 +255,20 @@ class MorletFilterBasis(FilterBasis): ...@@ -252,19 +255,20 @@ class MorletFilterBasis(FilterBasis):
# get relevant indices # get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool)) iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
# get x and y # get corresponding r, phi, x and y coordinates
x = r * torch.sin(phi) / r_cutoff r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
y = r * torch.cos(phi) / r_cutoff phi = phi[iidx[:, 1], iidx[:, 2]]
x = r * torch.sin(phi)
harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel / 2) * math.pi * x / width), torch.cos(torch.ceil(nkernel / 2) * math.pi * x / width)) y = r * torch.cos(phi)
harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel / 2) * math.pi * y / width), torch.cos(torch.ceil(mkernel / 2) * math.pi * y / width)) n = nkernel[iidx[:, 0], 0, 0]
m = mkernel[iidx[:, 0], 0, 0]
# disk area harmonic = torch.where(n % 2 == 1, torch.sin(torch.ceil(n / 2) * math.pi * x / width), torch.cos(torch.ceil(n / 2) * math.pi * x / width))
# disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff)) harmonic *= torch.where(m % 2 == 1, torch.sin(torch.ceil(m / 2) * math.pi * y / width), torch.cos(torch.ceil(m / 2) * math.pi * y / width))
disk_area = 1.0
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25 # computes the envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals = self.gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area # vals = self.gaussian_window(r, width=width) * harmonic
vals = self.hann_window(r, width=width) * harmonic
return iidx, vals return iidx, vals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment