Unverified Commit 4dadf551 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

architectural improvements to sfno (#18)

Major Cleanup of SFNO. Retiring non-linear architecture and fixing initialization. Adding scripts for training and validation.
parent 08108157
This diff is collapsed.
This diff is collapsed.
......@@ -108,16 +108,16 @@
"output_type": "stream",
"text": [
"/home/bbonev/.zshenv:export:2: not valid in this context: :/usr/local/cuda-11.7/lib64\n",
"--2023-10-24 18:08:10-- https://astropedia.astrogeology.usgs.gov/download/Mars/GlobalSurveyor/MOLA/thumbs/Mars_MGS_MOLA_DEM_mosaic_global_1024.jpg\n",
"--2023-10-30 18:00:14-- https://astropedia.astrogeology.usgs.gov/download/Mars/GlobalSurveyor/MOLA/thumbs/Mars_MGS_MOLA_DEM_mosaic_global_1024.jpg\n",
"Resolving astropedia.astrogeology.usgs.gov (astropedia.astrogeology.usgs.gov)... 137.227.239.81, 2001:49c8:c000:122d::81\n",
"Connecting to astropedia.astrogeology.usgs.gov (astropedia.astrogeology.usgs.gov)|137.227.239.81|:443... connected.\n",
"HTTP request sent, awaiting response... 200 \n",
"Length: 55192 (54K) [image/jpeg]\n",
"Saving to: ‘./data/mola_topo.jpg’\n",
"\n",
"./data/mola_topo.jp 100%[===================>] 53.90K 161KB/s in 0.3s \n",
"./data/mola_topo.jp 100%[===================>] 53.90K 154KB/s in 0.3s \n",
"\n",
"2023-10-24 18:08:12 (161 KB/s) - ‘./data/mola_topo.jpg’ saved [55192/55192]\n",
"2023-10-30 18:00:15 (154 KB/s) - ‘./data/mola_topo.jpg’ saved [55192/55192]\n",
"\n"
]
}
......@@ -142,7 +142,7 @@
{
"data": {
"text/plain": [
"<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f991436a230>"
"<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f49e4952380>"
]
},
"execution_count": 4,
......@@ -178,46 +178,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"iter: 0, loss: 504.56821962467404\n",
"iter: 1, loss: 0.00802396426749307\n",
"iter: 2, loss: 0.008023963812431065\n",
"iter: 3, loss: 0.008023963784318747\n",
"iter: 4, loss: 0.008023962882019332\n",
"iter: 5, loss: 0.008023963275982648\n",
"iter: 6, loss: 0.008023962667711045\n",
"iter: 7, loss: 0.008023963782547126\n",
"iter: 8, loss: 0.008023963340130377\n",
"iter: 9, loss: 0.008023963717686556\n",
"iter: 10, loss: 0.008023963189075497\n",
"iter: 11, loss: 0.008023963662749444\n",
"iter: 12, loss: 0.008023964217954163\n",
"iter: 13, loss: 0.008023963645109735\n",
"iter: 14, loss: 0.008023963884895183\n",
"iter: 15, loss: 0.008023963417559243\n",
"iter: 16, loss: 0.008023963709925376\n",
"iter: 17, loss: 0.008023963864442468\n",
"iter: 18, loss: 0.008023963186281617\n",
"iter: 19, loss: 0.008023962844331859\n",
"iter: 20, loss: 0.008023963578808139\n",
"iter: 21, loss: 0.00802396382884392\n",
"iter: 22, loss: 0.008023963250166802\n",
"iter: 23, loss: 0.008023963424637747\n",
"iter: 24, loss: 0.008023964456974\n",
"iter: 25, loss: 0.00802396354425496\n",
"iter: 26, loss: 0.008023964264189777\n",
"iter: 27, loss: 0.008023963659278077\n",
"iter: 28, loss: 0.008023963463597659\n",
"iter: 29, loss: 0.008023963289571119\n",
"iter: 30, loss: 0.008023964016864156\n",
"iter: 31, loss: 0.008023963531573766\n",
"iter: 32, loss: 0.008023963437000084\n",
"iter: 33, loss: 0.008023964116843215\n",
"iter: 34, loss: 0.008023962721410783\n",
"iter: 35, loss: 0.008023963977951472\n",
"iter: 36, loss: 0.008023963204566793\n",
"iter: 37, loss: 0.00802396369010344\n",
"iter: 38, loss: 0.008023963907011133\n",
"iter: 39, loss: 0.008023963523688133\n"
"iter: 0, loss: 453.0968931302793\n",
"iter: 1, loss: 0.008023964326606358\n",
"iter: 2, loss: 0.008023963388341868\n",
"iter: 3, loss: 0.008023963340660247\n",
"iter: 4, loss: 0.008023963596959654\n",
"iter: 5, loss: 0.008023963735337598\n",
"iter: 6, loss: 0.008023964260612844\n",
"iter: 7, loss: 0.008023964042363394\n",
"iter: 8, loss: 0.00802396368406042\n",
"iter: 9, loss: 0.008023962714947052\n",
"iter: 10, loss: 0.008023963489819921\n",
"iter: 11, loss: 0.008023963701078593\n",
"iter: 12, loss: 0.008023962923266034\n",
"iter: 13, loss: 0.008023964198518512\n",
"iter: 14, loss: 0.008023962813486126\n",
"iter: 15, loss: 0.008023964110803488\n",
"iter: 16, loss: 0.00802396403813473\n",
"iter: 17, loss: 0.008023963786036484\n",
"iter: 18, loss: 0.008023964195574898\n",
"iter: 19, loss: 0.008023963516124565\n",
"iter: 20, loss: 0.008023964508201684\n",
"iter: 21, loss: 0.008023963767474551\n",
"iter: 22, loss: 0.008023963648388185\n",
"iter: 23, loss: 0.008023963972575866\n",
"iter: 24, loss: 0.008023964038780116\n",
"iter: 25, loss: 0.008023963707541834\n",
"iter: 26, loss: 0.008023963269911932\n",
"iter: 27, loss: 0.008023963391352053\n",
"iter: 28, loss: 0.008023963414851426\n",
"iter: 29, loss: 0.008023964147064296\n",
"iter: 30, loss: 0.008023963760174639\n",
"iter: 31, loss: 0.008023963924162339\n",
"iter: 32, loss: 0.00802396360354566\n",
"iter: 33, loss: 0.00802396407422616\n",
"iter: 34, loss: 0.008023962918493041\n",
"iter: 35, loss: 0.008023963622013491\n",
"iter: 36, loss: 0.0080239635670241\n",
"iter: 37, loss: 0.008023963871070301\n",
"iter: 38, loss: 0.008023963587685968\n",
"iter: 39, loss: 0.008023963496770136\n"
]
}
],
......@@ -271,7 +271,7 @@
{
"data": {
"text/plain": [
"<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f99039db190>"
"<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f49d214b9a0>"
]
},
"execution_count": 6,
......
......@@ -31,6 +31,7 @@
import numpy as np
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
def plot_sphere(data,
......@@ -38,10 +39,12 @@ def plot_sphere(data,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_latitude=20,
central_longitude=20,
lon=None,
lat=None):
lat=None,
**kwargs):
if fig == None:
fig = plt.figure()
......@@ -61,8 +64,9 @@ def plot_sphere(data,
Lat = Lat*180/np.pi
# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False)
# ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
if colorbar:
plt.colorbar(im)
plt.title(title, y=1.05)
......@@ -76,7 +80,8 @@ def plot_data(data,
title=None,
colorbar=False,
lon=None,
lat=None):
lat=None,
**kwargs):
if fig == None:
fig = plt.figure()
......@@ -90,7 +95,8 @@ def plot_data(data,
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 1, 1, projection=projection)
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap)
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, **kwargs)
if colorbar:
plt.colorbar(im)
plt.title(title, y=1.05)
......
This diff is collapsed.
......@@ -43,8 +43,8 @@ from .activations import *
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# tl.set_backend('pytorch')
# use_opt_einsum('optimal')
# tl.set_backend("pytorch")
# use_opt_einsum("optimal")
from tltorch.factorized_tensors.core import FactorizedTensor
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
......@@ -137,21 +137,37 @@ class MLP(nn.Module):
in_features,
hidden_features = None,
out_features = None,
act_layer = nn.GELU,
output_bias = True,
act_layer = nn.ReLU,
output_bias = False,
drop_rate = 0.,
checkpointing = False):
checkpointing = False,
gain = 1.0):
super(MLP, self).__init__()
self.checkpointing = checkpointing
out_features = out_features or in_features
hidden_features = hidden_features or in_features
# Fist dense layer
fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
# ln1 = norm_layer(num_features=hidden_features)
# initialize the weights correctly
scale = math.sqrt(2.0 / in_features)
nn.init.normal_(fc1.weight, mean=0., std=scale)
if fc1.bias is not None:
nn.init.constant_(fc1.bias, 0.0)
# activation
act = act_layer()
fc2 = nn.Conv2d(hidden_features, out_features, 1, bias = output_bias)
# output layer
fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
# gain factor for the output determines the scaling of the output init
scale = math.sqrt(gain / hidden_features)
nn.init.normal_(fc2.weight, mean=0., std=scale)
if fc2.bias is not None:
nn.init.constant_(fc2.bias, 0.0)
if drop_rate > 0.:
drop = nn.Dropout(drop_rate)
drop = nn.Dropout2d(drop_rate)
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
self.fwd = nn.Sequential(fc1, act, fc2)
......@@ -218,15 +234,12 @@ class SpectralConvS2(nn.Module):
inverse_transform,
in_channels,
out_channels,
scale = 'auto',
operator_type = 'driscoll-healy',
gain = 2.,
operator_type = "driscoll-healy",
lr_scale_exponent = 0,
bias = False):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
......@@ -242,33 +255,31 @@ class SpectralConvS2(nn.Module):
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
weight_shape = [in_channels, out_channels]
weight_shape = [out_channels, in_channels]
if self.operator_type == 'diagonal':
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
from .contractions import contract_diagonal as _contract
elif self.operator_type == 'block-diagonal':
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
from .contractions import contract_blockdiag as _contract
elif self.operator_type == 'driscoll-healy':
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
from .contractions import contract_dhconv as _contract
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
# rescale the learning rate for better training of spectral parameters
lr_scale = (torch.arange(self.modes_lat)+1).reshape(-1, 1)**(lr_scale_exponent)
self.register_buffer("lr_scale", lr_scale)
# self.weight.register_hook(lambda grad: self.lr_scale*grad)
scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
scale[0] *= math.sqrt(2)
self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
# self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
......@@ -290,7 +301,7 @@ class SpectralConvS2(nn.Module):
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
if hasattr(self, 'bias'):
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
......@@ -306,19 +317,16 @@ class FactorizedSpectralConvS2(nn.Module):
inverse_transform,
in_channels,
out_channels,
scale = 'auto',
operator_type = 'driscoll-healy',
gain = 2.,
operator_type = "driscoll-healy",
rank = 0.2,
factorization = None,
separable = False,
implementation = 'factorized',
implementation = "factorized",
decomposition_kwargs=dict(),
bias = False):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
......@@ -330,9 +338,9 @@ class FactorizedSpectralConvS2(nn.Module):
# Make sure we are using a Complex Factorized Tensor
if factorization is None:
factorization = 'Dense' # No factorization
if not factorization.lower().startswith('complex'):
factorization = f'Complex{factorization}'
factorization = "Dense" # No factorization
if not factorization.lower().startswith("complex"):
factorization = f"Complex{factorization}"
# remember factorization details
self.operator_type = operator_type
......@@ -343,16 +351,16 @@ class FactorizedSpectralConvS2(nn.Module):
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
weight_shape = [in_channels]
weight_shape = [out_channels]
if not self.separable:
weight_shape += [out_channels]
weight_shape += [in_channels]
if self.operator_type == 'diagonal':
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
elif self.operator_type == 'block-diagonal':
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
elif self.operator_type == 'driscoll-healy':
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
......@@ -362,6 +370,7 @@ class FactorizedSpectralConvS2(nn.Module):
fixed_rank_modes=False, **decomposition_kwargs)
# initialization of weights
scale = math.sqrt(gain / in_channels)
self.weight.normal_(0, scale)
# get the right contraction function
......@@ -369,7 +378,7 @@ class FactorizedSpectralConvS2(nn.Module):
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias:
self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
......@@ -388,242 +397,8 @@ class FactorizedSpectralConvS2(nn.Module):
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
if hasattr(self, 'bias'):
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class SpectralAttention2d(nn.Module):
    """
    Geometrical spectral attention layer.

    Pipeline in ``forward``: forward transform -> complex-valued pointwise
    MLP on the spectral coefficients (``forward_mlp``) -> inverse transform.
    Returns the transformed signal together with a residual (the float32
    input itself, or its round trip through the inverse transform when the
    input and output grids differ).

    NOTE(review): indentation was reconstructed from a collapsed diff view;
    statement nesting follows the visible syntax — confirm against the
    original file.
    """
    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 sparsity_threshold = 0.0,
                 hidden_size_factor = 2,
                 use_complex_kernels = False,
                 complex_activation = 'real',
                 bias = False,
                 spectral_layers = 1,
                 drop_rate = 0.):
        super(SpectralAttention2d, self).__init__()

        self.embed_dim = embed_dim
        # NOTE(review): sparsity_threshold is stored but never read anywhere
        # in this class as visible here.
        self.sparsity_threshold = sparsity_threshold
        # hidden width of the spectral MLP
        self.hidden_size = int(hidden_size_factor * self.embed_dim)
        # initialization scale shared by all spectral weights
        self.scale = 1 / embed_dim**2
        # pick fused complex kernels or the plain real-pair implementations
        self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
        self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
        self.spectral_layers = spectral_layers

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        # residual must be re-gridded whenever input and output grids differ
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
            or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        # weights: complex values stored as real/imag pairs in a trailing
        # dimension of size 2 (consumed via torch.view_as_complex)
        w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)]
        for l in range(1, self.spectral_layers):
            w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2))
        self.w = nn.ParameterList(w)

        if bias:
            # one complex bias per hidden layer; the mere presence of self.b
            # switches forward_mlp onto the mul-add path
            self.b = nn.ParameterList([self.scale * torch.randn(self.hidden_size, 1, 2) for _ in range(self.spectral_layers)])

        # output projection back to embed_dim channels
        self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2))

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

        self.activations = nn.ModuleList([])
        for l in range(0, self.spectral_layers):
            self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1), scale=self.scale))

    def forward_mlp(self, x):
        """Apply the complex MLP pointwise to the spectral coefficients ``x``."""
        x = torch.view_as_real(x)
        xr = x
        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            # activation and dropout operate on the complex view
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)
        x = self.mul_handle(xr, self.wout)
        x = torch.view_as_complex(x)
        return x

    def forward(self, x):
        # transforms run in float32 with amp autocast explicitly disabled
        dtype = x.dtype
        x = x.float()
        residual = x

        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                # re-grid the residual onto the output grid
                residual = self.inverse_transform(x)

        x = self.forward_mlp(x)

        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        # cast back to the caller's precision
        x = x.type(dtype)
        return x, residual
class SpectralAttentionS2(nn.Module):
    """
    Spherical non-linear FNO layer.

    Pipeline in ``forward``: spherical forward transform -> complex-valued
    pointwise MLP on the spectral coefficients (``forward_mlp``) -> inverse
    transform.  Supports a per-coefficient ('diagonal') or per-degree
    ('driscoll-healy') weight layout.  Returns the transformed signal plus a
    residual (the float32 input, re-gridded through the inverse transform
    when input and output grids differ).

    NOTE(review): indentation was reconstructed from a collapsed diff view;
    statement nesting follows the visible syntax — confirm against the
    original file.
    """
    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 operator_type = 'diagonal',
                 sparsity_threshold = 0.0,
                 hidden_size_factor = 2,
                 complex_activation = 'real',
                 scale = 'auto',
                 bias = False,
                 spectral_layers = 1,
                 drop_rate = 0.):
        super(SpectralAttentionS2, self).__init__()

        self.embed_dim = embed_dim
        # NOTE(review): sparsity_threshold is stored but never read anywhere
        # in this class as visible here.
        self.sparsity_threshold = sparsity_threshold
        self.operator_type = operator_type
        self.spectral_layers = spectral_layers

        # BUGFIX: self.scale was previously assigned only for scale == 'auto';
        # passing an explicit numeric scale left the attribute unset and every
        # weight initialization below raised AttributeError.
        if scale == 'auto':
            self.scale = (1 / (embed_dim * embed_dim))
        else:
            self.scale = scale

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        # residual must be re-gridded whenever input and output grids differ
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
            or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        # hidden width of the spectral MLP
        hidden_size = int(hidden_size_factor * self.embed_dim)

        if operator_type == 'diagonal':
            # one complex weight per spectral coefficient, shared over (l, m)
            self.mul_add_handle = compl_muladd2d_fwd
            self.mul_handle = compl_mul2d_fwd

            # weights: complex values stored as real/imag pairs (trailing dim 2)
            w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, self.embed_dim, 2))

            if bias:
                # presence of self.b switches forward_mlp onto the mul-add path
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        elif operator_type == 'driscoll-healy':
            # separate weights per latitudinal mode l
            self.mul_add_handle = compl_exp_muladd2d_fwd
            self.mul_handle = compl_exp_mul2d_fwd

            # weights: complex values stored as real/imag pairs (trailing dim 2)
            w = [self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(self.modes_lat, hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            if bias:
                # presence of self.b switches forward_mlp onto the mul-add path
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.wout = nn.Parameter(self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2))

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        else:
            raise ValueError('Unknown operator type')

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward_mlp(self, x):
        """Apply the complex MLP pointwise to the spectral coefficients ``x``."""
        # implicit 4-D sanity check on the spectral tensor layout
        B, C, H, W = x.shape

        xr = torch.view_as_real(x)

        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            # activation and dropout operate on the complex view
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        # final MLP
        x = self.mul_handle(xr, self.wout)

        x = torch.view_as_complex(x)

        return x

    def forward(self, x):
        # transforms run in float32 with amp autocast explicitly disabled
        dtype = x.dtype
        x = x.to(torch.float32)
        residual = x

        # FWD transform
        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                # re-grid the residual onto the output grid
                residual = self.inverse_transform(x)

        # MLP
        x = self.forward_mlp(x)

        # BWD transform
        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        # cast back to initial precision
        x = x.to(dtype)
        return x, residual
......@@ -239,7 +239,7 @@ class ShallowWaterSolver(nn.Module):
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
# mach number relative to wave speed
llimit = mlimit = 20
llimit = mlimit = 80
# hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
# ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment