Unverified Commit 4dadf551, authored by Boris Bonev, committed by GitHub

architectural improvements to sfno (#18)

Major cleanup of SFNO: retires the non-linear architecture, fixes initialization, and adds scripts for training and validation.
parent 08108157
@@ -46,13 +46,12 @@ import pandas as pd
 import matplotlib.pyplot as plt
 from torch_harmonics.examples.sfno import PdeDataset
-from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO

 # wandb logging
 import wandb
 wandb.login()

-def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
+def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
     loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
     if relative:
         loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
@@ -63,7 +62,7 @@ def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
     return loss

-def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=False):
+def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
     # compute coefficients
     coeffs = torch.view_as_real(solver.sht(prd - tar))
     coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
@@ -83,7 +82,7 @@ def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=False):
     return loss

-def spectral_loss_sphere(solver, prd, tar, relative=False, squared=False):
+def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
     # gradient weighting factors
     lmax = solver.sht.lmax
     ls = torch.arange(lmax).float()
@@ -110,7 +109,7 @@ def spectral_loss_sphere(solver, prd, tar, relative=False, squared=False):
     return loss

-def h1loss_sphere(solver, prd, tar, relative=False, squared=False):
+def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
     # gradient weighting factors
     lmax = solver.sht.lmax
     ls = torch.arange(lmax).float()
@@ -139,7 +138,6 @@ def h1loss_sphere(solver, prd, tar, relative=False, squared=False):
     return loss

 def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
     # compute the weighting factor first
     fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
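Note: the grid-based losses above are quadrature-weighted mean-square errors over the sphere. A minimal standalone sketch of the same computation, assuming an equiangular grid and replacing solver.integrate_grid with explicit sin(theta) quadrature weights (illustrative only, not the library's implementation):

import torch

def l2loss_equiangular(prd, tar, relative=False):
    # prd, tar: (batch, channels, nlat, nlon) tensors on an equiangular grid
    nlat = prd.shape[-2]
    theta = torch.linspace(0.5, nlat - 0.5, nlat) * torch.pi / nlat   # colatitudes of the rings
    w = torch.sin(theta) / torch.sin(theta).sum()                     # normalized ring weights
    sq_err = ((prd - tar)**2).mean(dim=-1)                            # average over longitude
    loss = (sq_err * w).sum(dim=-1).sum(dim=-1)                       # weighted sum over latitude, sum over channels
    if relative:
        loss = loss / (((tar**2).mean(dim=-1)) * w).sum(dim=-1).sum(dim=-1)
    return loss.mean()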
@@ -152,36 +150,107 @@ def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
     loss = torch.mean(loss)
     return loss

-def main(train=True, load_checkpoint=False, enable_amp=False):
-
-    # set seed
-    torch.manual_seed(333)
-    torch.cuda.manual_seed(333)
-
-    # set device
-    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
-    if torch.cuda.is_available():
-        torch.cuda.set_device(device.index)
-
-    # 1 hour prediction steps
-    dt = 1*3600
-    dt_solver = 150
-    nsteps = dt//dt_solver
-    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
-    # There is still an issue with parallel dataloading. Do NOT use it at the moment
-    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
-    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
-    solver = dataset.solver.to(device)
-
-    nlat = dataset.nlat
-    nlon = dataset.nlon
+# rolls out the FNO and compares to the classical solver
+def autoregressive_inference(model,
+                             dataset,
+                             path_root,
+                             nsteps,
+                             autoreg_steps=10,
+                             nskip=1,
+                             plot_channel=0,
+                             nics=20):
+
+    model.eval()
+
+    losses = np.zeros(nics)
+    fno_times = np.zeros(nics)
+    nwp_times = np.zeros(nics)
+
+    for iic in range(nics):
+        ic = dataset.solver.random_initial_condition(mach=0.2)
+        inp_mean = dataset.inp_mean
+        inp_var = dataset.inp_var
+
+        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
+        prd = prd.unsqueeze(0)
+        uspec = ic.clone()
+
+        # ML model
+        start_time = time.time()
+        for i in range(1, autoreg_steps+1):
+            # evaluate the ML model
+            prd = model(prd)
+
+            if iic == nics-1 and nskip > 0 and i % nskip == 0:
+                # do plotting
+                fig = plt.figure(figsize=(7.5, 6))
+                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
+                plt.savefig(path_root+'_pred_'+str(i//nskip)+'.png')
+                plt.clf()
+
+        fno_times[iic] = time.time() - start_time
+
+        # classical model
+        start_time = time.time()
+        for i in range(1, autoreg_steps+1):
+            # advance classical model
+            uspec = dataset.solver.timestep(uspec, nsteps)
+
+            if iic == nics-1 and i % nskip == 0 and nskip > 0:
+                ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
+                fig = plt.figure(figsize=(7.5, 6))
+                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
+                plt.savefig(path_root+'_truth_'+str(i//nskip)+'.png')
+                plt.clf()
+
+        nwp_times[iic] = time.time() - start_time
+
+        # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
+        ref = dataset.solver.spec2grid(uspec)
+        prd = prd * torch.sqrt(inp_var) + inp_mean
+        losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
+
+    return losses, fno_times, nwp_times
+
+# convenience function for logging weights and gradients
+def log_weights_and_grads(model, iters=1):
+    """
+    Helper routine intended for debugging purposes
+    """
+    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
+    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
+    print(weights_and_grads_fname)
+
+    weights_dict = {k: v for k, v in model.named_parameters()}
+    grad_dict = {k: v.grad for k, v in model.named_parameters()}
+
+    store_dict = {'iteration': iters, 'grads': grad_dict, 'weights': weights_dict}
+    torch.save(store_dict, weights_and_grads_fname)

 # training function
-def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn='l2'):
+def train_model(model,
+                dataloader,
+                optimizer,
+                gscaler,
+                scheduler=None,
+                nepochs=20,
+                nfuture=0,
+                num_examples=256,
+                num_valid=8,
+                loss_fn='l2',
+                enable_amp=False,
+                log_grads=0):

     train_start = time.time()

+    # count iterations
+    iters = 0
+
     for epoch in range(nepochs):

         # time each epoch
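Note: the snapshots written by log_weights_and_grads are ordinary torch.save archives, so they can be inspected offline. A hypothetical inspection snippet (the path and layout follow the store_dict above):

import torch

snap = torch.load("weights_and_grads/weights_and_grads_step001.tar")
print("iteration:", snap["iteration"])
for name, w in snap["weights"].items():
    g = snap["grads"][name]
    gnorm = float("nan") if g is None else g.norm().item()
    print(f"{name:60s} |w|={w.norm().item():.3e} |grad|={gnorm:.3e}")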
@@ -190,6 +259,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
         dataloader.dataset.set_initial_condition('random')
         dataloader.dataset.set_num_examples(num_examples)

+        # get the solver for its convenience functions
+        solver = dataloader.dataset.solver
+
         # do the training
         acc_loss = 0
         model.train()
@@ -204,6 +276,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
             if loss_fn == 'l2':
                 loss = l2loss_sphere(solver, prd, tar, relative=False)
+            elif loss_fn == 'spectral l2':
+                loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
             elif loss_fn == 'h1':
                 loss = h1loss_sphere(solver, prd, tar, relative=False)
             elif loss_fn == 'spectral':
@@ -216,11 +290,16 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
             acc_loss += loss.item() * inp.size(0)

             optimizer.zero_grad(set_to_none=True)
-            # gscaler.scale(loss).backward()
             gscaler.scale(loss).backward()
+
+            if log_grads and iters % log_grads == 0:
+                log_weights_and_grads(model, iters=iters)
+
             gscaler.step(optimizer)
             gscaler.update()

+            iters += 1
+
         acc_loss = acc_loss / len(dataloader.dataset)

         dataloader.dataset.set_initial_condition('random')
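Note: the scale/step/update sequence above is the standard torch.cuda.amp recipe. A self-contained sketch of a single step, with a placeholder model and batch rather than the script's objects (requires a GPU):

import torch
from torch.cuda import amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
gscaler = amp.GradScaler(enabled=True)

x = torch.randn(4, 16, device="cuda")
with amp.autocast(enabled=True):            # forward pass in mixed precision
    loss = (model(x)**2).mean()

optimizer.zero_grad(set_to_none=True)
gscaler.scale(loss).backward()              # scale the loss to avoid fp16 gradient underflow
gscaler.step(optimizer)                     # unscales gradients; skips the step on inf/nan
gscaler.update()                            # adapts the scale factor for the next iteration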
@@ -262,64 +341,28 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
     print(f'done. Training took {train_time}.')

     return valid_loss

-# rolls out the FNO and compares to the classical solver
-def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=20):
-
-    model.eval()
-
-    losses = np.zeros(nics)
-    fno_times = np.zeros(nics)
-    nwp_times = np.zeros(nics)
-
-    for iic in range(nics):
-        ic = dataset.solver.random_initial_condition(mach=0.2)
-        inp_mean = dataset.inp_mean
-        inp_var = dataset.inp_var
-
-        prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
-        prd = prd.unsqueeze(0)
-        uspec = ic.clone()
-
-        # ML model
-        start_time = time.time()
-        for i in range(1, autoreg_steps+1):
-            # evaluate the ML model
-            prd = model(prd)
-
-            if iic == nics-1 and nskip > 0 and i % nskip == 0:
-                # do plotting
-                fig = plt.figure(figsize=(7.5, 6))
-                dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
-                plt.savefig(path_root+'_pred_'+str(i//nskip)+'.png')
-                plt.clf()
-
-        fno_times[iic] = time.time() - start_time
-
-        # classical model
-        start_time = time.time()
-        for i in range(1, autoreg_steps+1):
-            # advance classical model
-            uspec = dataset.solver.timestep(uspec, nsteps)
-
-            if iic == nics-1 and i % nskip == 0 and nskip > 0:
-                ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
-                fig = plt.figure(figsize=(7.5, 6))
-                dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
-                plt.savefig(path_root+'_truth_'+str(i//nskip)+'.png')
-                plt.clf()
-
-        nwp_times[iic] = time.time() - start_time
-
-        # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
-        ref = dataset.solver.spec2grid(uspec)
-        prd = prd * torch.sqrt(inp_var) + inp_mean
-        losses[iic] = l2loss_sphere(solver, prd, ref, relative=True).item()
-
-    return losses, fno_times, nwp_times
+def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
+
+    # set seed
+    torch.manual_seed(333)
+    torch.cuda.manual_seed(333)
+
+    # set device
+    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
+    if torch.cuda.is_available():
+        torch.cuda.set_device(device.index)
+
+    # 1 hour prediction steps
+    dt = 1*3600
+    dt_solver = 150
+    nsteps = dt//dt_solver
+    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
+    # There is still an issue with parallel dataloading. Do NOT use it at the moment
+    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
+
+    nlat = dataset.nlat
+    nlon = dataset.nlon
 def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -328,20 +371,28 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
     models = {}
     metrics = {}

+    from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
+
+    models["sfno_sc3_layer4_e16_linskip_nomlp"] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
+                                                          num_layers=4, scale_factor=3, embed_dim=16, operator_type='driscoll-healy',
+                                                          big_skip=False, pos_embed=False, use_mlp=False, normalization_layer="none")
+
+    # models["sfno_sc3_layer4_e256_noskip_mlp"] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
+    #                                                     num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy',
+    #                                                     big_skip=False, pos_embed=False, use_mlp=True, normalization_layer="none")
+
+    # from torch_harmonics.examples.sfno.models.unet import UNet
+    # models['unet_baseline'] = partial(UNet)

     # # U-Net if installed
     # from models.unet import UNet
     # models['unet_baseline'] = partial(UNet)

     # SFNO models
-    models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
-                                                       num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
-    models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
-                                                     num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real', operator_type='diagonal')
-
-    # FNO models
-    models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', filter_type='linear', img_size=(nlat, nlon),
-                                                      num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')
-    models['fno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='fft', filter_type='non-linear', img_size=(nlat, nlon),
-                                                    num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real')
+    # models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
+    #                                                    num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
+
+    # # FNO models
+    # models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', img_size=(nlat, nlon), grid="equiangular",
+    #                                                   num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')

     # iterate over models and train each model
     root_path = os.path.dirname(__file__)
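Note: the registry stores constructors rather than instances; functools.partial freezes the hyperparameters so each entry can be instantiated fresh per experiment, and partial's .keywords dict is what gets passed to wandb.init as the run config below. A minimal sketch of the pattern with a stand-in module:

from functools import partial
import torch.nn as nn

class TinyNet(nn.Module):   # stand-in for SFNO
    def __init__(self, embed_dim=16, num_layers=4):
        super().__init__()
        self.body = nn.Sequential(*[nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)])

models = {
    "tiny_e16_l4": partial(TinyNet, embed_dim=16, num_layers=4),
    "tiny_e32_l2": partial(TinyNet, embed_dim=32, num_layers=2),
}

for name, model_handle in models.items():
    model = model_handle()   # fresh instance; model_handle.keywords holds the config
    print(name, sum(p.numel() for p in model.parameters()))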
@@ -349,6 +400,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
         model = model_handle().to(device)

+        print(model)
+
         metrics[model_name] = {}

         num_params = count_parameters(model)
@@ -360,26 +413,26 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
         # run the training
         if train:
-            run = wandb.init(project="sfno spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)
+            run = wandb.init(project="sfno ablations spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)

             # optimizer:
-            optimizer = torch.optim.Adam(model.parameters(), lr=1E-3)
+            optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
             scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
             gscaler = amp.GradScaler(enabled=enable_amp)

             start_time = time.time()

             print(f'Training {model_name}, single step')
-            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn='l2')
+            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn='l2', enable_amp=enable_amp, log_grads=log_grads)

-            # multistep training
-            print(f'Training {model_name}, two step')
-            optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
-            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
-            gscaler = amp.GradScaler(enabled=enable_amp)
-            dataloader.dataset.nsteps = 2 * dt//dt_solver
-            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1)
-            dataloader.dataset.nsteps = 1 * dt//dt_solver
+            # # multistep training
+            # print(f'Training {model_name}, two step')
+            # optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
+            # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
+            # gscaler = amp.GradScaler(enabled=enable_amp)
+            # dataloader.dataset.nsteps = 2 * dt//dt_solver
+            # train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1, enable_amp=enable_amp)
+            # dataloader.dataset.nsteps = 1 * dt//dt_solver

             training_time = time.time() - start_time
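Note: unlike most schedulers, ReduceLROnPlateau must be stepped with a metric; presumably the validation loss computed inside train_model is what drives it. A standalone sketch of the pattern:

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

for epoch in range(20):
    val_loss = 1.0   # placeholder for a real validation loss
    scheduler.step(val_loss)   # the LR only drops once val_loss stops improving
    print(epoch, optimizer.param_groups[0]['lr'])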
@@ -392,7 +445,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
         torch.cuda.manual_seed(333)

         with torch.inference_mode():
-            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path,'paper_figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
+            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path,'figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
             metrics[model_name]['loss_mean'] = np.mean(losses)
             metrics[model_name]['loss_std'] = np.std(losses)
             metrics[model_name]['fno_time_mean'] = np.mean(fno_times)
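Note: since the script imports pandas, a natural way to consume the metrics dict is a DataFrame; this post-processing sketch is an assumption, not something shown in the diff:

import pandas as pd

metrics = {
    "sfno_sc3_layer4_e16_linskip_nomlp": {"loss_mean": 0.012, "loss_std": 0.003, "fno_time_mean": 0.8},
}   # illustrative values only

df = pd.DataFrame(metrics).T   # one row per model, one column per metric
print(df.sort_values("loss_mean"))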
@@ -409,4 +462,4 @@ if __name__ == "__main__":
     import torch.multiprocessing as mp
     mp.set_start_method('forkserver', force=True)

-    main(train=True, load_checkpoint=False, enable_amp=False)
+    main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)
[Source diff for one file not displayed because it is too large.]

[Notebook diff: only re-executed cell outputs changed. The wget log for the Mars MOLA topography download (Mars_MGS_MOLA_DEM_mosaic_global_1024.jpg, 55192 bytes) shows new timestamps (2023-10-24 18:08 → 2023-10-30 18:00) and transfer rate (161 KB/s → 154 KB/s); the repr addresses of two cartopy GeoQuadMesh objects changed; and in the 40-iteration optimization log the iter-0 loss changed from 504.5682 to 453.0969, while iterations 1-39 remained at ≈0.008023963 up to floating-point noise.]
@@ -31,6 +31,7 @@
 import numpy as np
 import matplotlib.pyplot as plt

+import cartopy
 import cartopy.crs as ccrs

 def plot_sphere(data,
@@ -38,10 +39,12 @@ def plot_sphere(data,
                 cmap="RdBu",
                 title=None,
                 colorbar=False,
+                coastlines=False,
                 central_latitude=20,
                 central_longitude=20,
                 lon=None,
-                lat=None):
+                lat=None,
+                **kwargs):

     if fig == None:
         fig = plt.figure()
@@ -61,8 +64,9 @@ def plot_sphere(data,
     Lat = Lat*180/np.pi

     # contour data over the map.
-    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False)
-    # ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
+    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
+    if coastlines:
+        ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
     if colorbar:
         plt.colorbar(im)
     plt.title(title, y=1.05)
@@ -76,7 +80,8 @@ def plot_data(data,
               title=None,
               colorbar=False,
               lon=None,
-              lat=None):
+              lat=None,
+              **kwargs):

     if fig == None:
         fig = plt.figure()
@@ -90,7 +95,8 @@ def plot_data(data,
         fig = plt.figure(figsize=(10, 5))
         ax = fig.add_subplot(1, 1, 1, projection=projection)

-    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap)
+    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, **kwargs)
     if colorbar:
         plt.colorbar(im)
     plt.title(title, y=1.05)
[Source diff for one file not displayed because it is too large.]
@@ -43,8 +43,8 @@ from .activations import *
 # # import FactorizedTensor from tensorly for tensorized operations
 # import tensorly as tl
 # from tensorly.plugins import use_opt_einsum
-# tl.set_backend('pytorch')
-# use_opt_einsum('optimal')
+# tl.set_backend("pytorch")
+# use_opt_einsum("optimal")

 from tltorch.factorized_tensors.core import FactorizedTensor

 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
@@ -137,21 +137,37 @@ class MLP(nn.Module):
                 in_features,
                 hidden_features = None,
                 out_features = None,
-                act_layer = nn.GELU,
-                output_bias = True,
+                act_layer = nn.ReLU,
+                output_bias = False,
                 drop_rate = 0.,
-                checkpointing = False):
+                checkpointing = False,
+                gain = 1.0):
         super(MLP, self).__init__()
         self.checkpointing = checkpointing
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features

+        # first dense layer
         fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
-        # ln1 = norm_layer(num_features=hidden_features)
+        # initialize the weights correctly
+        scale = math.sqrt(2.0 / in_features)
+        nn.init.normal_(fc1.weight, mean=0., std=scale)
+        if fc1.bias is not None:
+            nn.init.constant_(fc1.bias, 0.0)
+
+        # activation
         act = act_layer()
-        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias = output_bias)
+
+        # output layer
+        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
+        # the gain factor determines the scaling of the output initialization
+        scale = math.sqrt(gain / hidden_features)
+        nn.init.normal_(fc2.weight, mean=0., std=scale)
+        if fc2.bias is not None:
+            nn.init.constant_(fc2.bias, 0.0)

         if drop_rate > 0.:
-            drop = nn.Dropout(drop_rate)
+            drop = nn.Dropout2d(drop_rate)
             self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
         else:
             self.fwd = nn.Sequential(fc1, act, fc2)
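Note: the fc1 initialization above is the He/Kaiming fan-in scheme written out by hand (std = sqrt(2/fan_in), appropriate for ReLU). A quick check against PyTorch's built-in initializer:

import math
import torch.nn as nn

fc = nn.Conv2d(256, 512, 1, bias=True)

# manual fan-in init, as in the MLP above
nn.init.normal_(fc.weight, mean=0., std=math.sqrt(2.0 / 256))
print(fc.weight.std().item())    # ~ sqrt(2/256) ~ 0.088

# equivalent built-in call: normal variant, fan_in mode, ReLU gain
nn.init.kaiming_normal_(fc.weight, mode='fan_in', nonlinearity='relu')
print(fc.weight.std().item())    # same scale, up to sampling noise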
@@ -218,15 +234,12 @@ class SpectralConvS2(nn.Module):
                 inverse_transform,
                 in_channels,
                 out_channels,
-                scale = 'auto',
-                operator_type = 'driscoll-healy',
+                gain = 2.,
+                operator_type = "driscoll-healy",
                 lr_scale_exponent = 0,
                 bias = False):
         super(SpectralConvS2, self).__init__()

-        if scale == 'auto':
-            scale = (2 / in_channels)**0.5
-
         self.forward_transform = forward_transform
         self.inverse_transform = inverse_transform
@@ -242,33 +255,31 @@ class SpectralConvS2(nn.Module):
         assert self.inverse_transform.lmax == self.modes_lat
         assert self.inverse_transform.mmax == self.modes_lon

-        weight_shape = [in_channels, out_channels]
+        weight_shape = [out_channels, in_channels]

-        if self.operator_type == 'diagonal':
+        if self.operator_type == "diagonal":
             weight_shape += [self.modes_lat, self.modes_lon]
             from .contractions import contract_diagonal as _contract
-        elif self.operator_type == 'block-diagonal':
+        elif self.operator_type == "block-diagonal":
             weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
             from .contractions import contract_blockdiag as _contract
-        elif self.operator_type == 'driscoll-healy':
+        elif self.operator_type == "driscoll-healy":
             weight_shape += [self.modes_lat]
             from .contractions import contract_dhconv as _contract
         else:
             raise NotImplementedError(f"Unknown operator type {self.operator_type}")

         # form weight tensors
-        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
-
-        # rescale the learning rate for better training of spectral parameters
-        lr_scale = (torch.arange(self.modes_lat)+1).reshape(-1, 1)**(lr_scale_exponent)
-        self.register_buffer("lr_scale", lr_scale)
-        # self.weight.register_hook(lambda grad: self.lr_scale*grad)
+        scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
+        scale[0] *= math.sqrt(2)
+        self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
+        # self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))

         # get the right contraction function
         self._contract = _contract

         if bias:
-            self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
+            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

     def forward(self, x):
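Note: the new initialization draws complex Gaussian weights with a per-degree scale of sqrt(gain/in_channels); the extra sqrt(2) on the first (l=0) row is an initialization choice whose rationale is not stated in the diff (plausibly compensating for how the lowest degree enters the real transform). A small sketch, under those assumptions, verifying the resulting per-entry variance:

import math
import torch

out_channels, in_channels, modes_lat, gain = 32, 64, 24, 2.0

scale = math.sqrt(gain / in_channels) * torch.ones(modes_lat, 2)
scale[0] *= math.sqrt(2)
w = scale * torch.view_as_real(torch.randn(out_channels, in_channels, modes_lat, dtype=torch.complex64))

# a complex randn puts variance 1/2 into each of the real and imaginary parts,
# so E|w|^2 = gain / in_channels per complex entry (twice that for l = 0)
print(w[..., 1:, :].pow(2).sum(-1).mean().item(), gain / in_channels)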
@@ -290,7 +301,7 @@ class SpectralConvS2(nn.Module):
         with amp.autocast(enabled=False):
             x = self.inverse_transform(x)

-        if hasattr(self, 'bias'):
+        if hasattr(self, "bias"):
             x = x + self.bias

         x = x.type(dtype)
@@ -306,19 +317,16 @@ class FactorizedSpectralConvS2(nn.Module):
                 inverse_transform,
                 in_channels,
                 out_channels,
-                scale = 'auto',
-                operator_type = 'driscoll-healy',
+                gain = 2.,
+                operator_type = "driscoll-healy",
                 rank = 0.2,
                 factorization = None,
                 separable = False,
-                implementation = 'factorized',
+                implementation = "factorized",
                 decomposition_kwargs=dict(),
                 bias = False):
         super(FactorizedSpectralConvS2, self).__init__()

-        if scale == 'auto':
-            scale = (2 / in_channels)**0.5
-
         self.forward_transform = forward_transform
         self.inverse_transform = inverse_transform
@@ -330,9 +338,9 @@ class FactorizedSpectralConvS2(nn.Module):
         # Make sure we are using a Complex Factorized Tensor
         if factorization is None:
-            factorization = 'Dense' # No factorization
-        if not factorization.lower().startswith('complex'):
-            factorization = f'Complex{factorization}'
+            factorization = "Dense" # No factorization
+        if not factorization.lower().startswith("complex"):
+            factorization = f"Complex{factorization}"
         # remember factorization details
         self.operator_type = operator_type
@@ -343,16 +351,16 @@ class FactorizedSpectralConvS2(nn.Module):
         assert self.inverse_transform.lmax == self.modes_lat
         assert self.inverse_transform.mmax == self.modes_lon

-        weight_shape = [in_channels]
+        weight_shape = [out_channels]

         if not self.separable:
-            weight_shape += [out_channels]
+            weight_shape += [in_channels]

-        if self.operator_type == 'diagonal':
+        if self.operator_type == "diagonal":
             weight_shape += [self.modes_lat, self.modes_lon]
-        elif self.operator_type == 'block-diagonal':
+        elif self.operator_type == "block-diagonal":
             weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
-        elif self.operator_type == 'driscoll-healy':
+        elif self.operator_type == "driscoll-healy":
             weight_shape += [self.modes_lat]
         else:
             raise NotImplementedError(f"Unknown operator type {self.operator_type}")
@@ -362,6 +370,7 @@ class FactorizedSpectralConvS2(nn.Module):
                                                fixed_rank_modes=False, **decomposition_kwargs)

         # initialization of weights
+        scale = math.sqrt(gain / in_channels)
         self.weight.normal_(0, scale)

         # get the right contraction function
@@ -369,7 +378,7 @@ class FactorizedSpectralConvS2(nn.Module):
         self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

         if bias:
-            self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
+            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
     def forward(self, x):
@@ -388,242 +397,8 @@ class FactorizedSpectralConvS2(nn.Module):
         with amp.autocast(enabled=False):
             x = self.inverse_transform(x)

-        if hasattr(self, 'bias'):
+        if hasattr(self, "bias"):
             x = x + self.bias

         x = x.type(dtype)

         return x, residual
(removed in this commit: the retired non-linear spectral attention layers, shown below as they stood before deletion)

class SpectralAttention2d(nn.Module):
    """
    geometrical Spectral Attention layer
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 sparsity_threshold = 0.0,
                 hidden_size_factor = 2,
                 use_complex_kernels = False,
                 complex_activation = 'real',
                 bias = False,
                 spectral_layers = 1,
                 drop_rate = 0.):
        super(SpectralAttention2d, self).__init__()

        self.embed_dim = embed_dim
        self.sparsity_threshold = sparsity_threshold
        self.hidden_size = int(hidden_size_factor * self.embed_dim)
        self.scale = 1 / embed_dim**2
        self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
        self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
        self.spectral_layers = spectral_layers

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
                           or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        # weights
        w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)]
        for l in range(1, self.spectral_layers):
            w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2))
        self.w = nn.ParameterList(w)

        if bias:
            self.b = nn.ParameterList([self.scale * torch.randn(self.hidden_size, 1, 2) for _ in range(self.spectral_layers)])

        self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2))

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

        self.activations = nn.ModuleList([])
        for l in range(0, self.spectral_layers):
            self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1), scale=self.scale))

    def forward_mlp(self, x):
        x = torch.view_as_real(x)
        xr = x
        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        x = self.mul_handle(xr, self.wout)
        x = torch.view_as_complex(x)

        return x

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        residual = x

        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = self.forward_mlp(x)

        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        x = x.type(dtype)

        return x, residual


class SpectralAttentionS2(nn.Module):
    """
    Spherical non-linear FNO layer
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 embed_dim,
                 operator_type = 'diagonal',
                 sparsity_threshold = 0.0,
                 hidden_size_factor = 2,
                 complex_activation = 'real',
                 scale = 'auto',
                 bias = False,
                 spectral_layers = 1,
                 drop_rate = 0.):
        super(SpectralAttentionS2, self).__init__()

        self.embed_dim = embed_dim
        self.sparsity_threshold = sparsity_threshold
        self.operator_type = operator_type
        self.spectral_layers = spectral_layers

        if scale == 'auto':
            self.scale = (1 / (embed_dim * embed_dim))

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
                           or (self.forward_transform.nlon != self.inverse_transform.nlon)

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        hidden_size = int(hidden_size_factor * self.embed_dim)

        if operator_type == 'diagonal':
            self.mul_add_handle = compl_muladd2d_fwd
            self.mul_handle = compl_mul2d_fwd

            # weights
            w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, self.embed_dim, 2))

            if bias:
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        elif operator_type == 'driscoll-healy':
            self.mul_add_handle = compl_exp_muladd2d_fwd
            self.mul_handle = compl_exp_mul2d_fwd

            # weights
            w = [self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(self.modes_lat, hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            if bias:
                self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])

            self.wout = nn.Parameter(self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2))

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

        else:
            raise ValueError('Unknown operator type')

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward_mlp(self, x):
        B, C, H, W = x.shape

        xr = torch.view_as_real(x)
        for l in range(self.spectral_layers):
            if hasattr(self, 'b'):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        # final MLP
        x = self.mul_handle(xr, self.wout)
        x = torch.view_as_complex(x)

        return x

    def forward(self, x):
        dtype = x.dtype
        x = x.to(torch.float32)
        residual = x

        # FWD transform
        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        # MLP
        x = self.forward_mlp(x)

        # BWD transform
        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        # cast back to initial precision
        x = x.to(dtype)

        return x, residual
\ No newline at end of file
@@ -38,6 +38,7 @@ from .layers import *
 from functools import partial

+
 class SpectralFilterLayer(nn.Module):
     """
     Fourier layer. Contains the convolution part of the FNO/SFNO
@@ -47,64 +48,37 @@ class SpectralFilterLayer(nn.Module):
             self,
             forward_transform,
             inverse_transform,
-            embed_dim,
-            filter_type = "non-linear",
+            input_dim,
+            output_dim,
+            gain = 2.,
             operator_type = "diagonal",
-            sparsity_threshold = 0.0,
-            use_complex_kernels = True,
             hidden_size_factor = 2,
-            lr_scale_exponent = 0,
             factorization = None,
             separable = False,
             rank = 1e-2,
-            complex_activation = "real",
-            spectral_layers = 1,
-            drop_rate = 0):
+            bias = True):
         super(SpectralFilterLayer, self).__init__()

-        if filter_type == "non-linear" and isinstance(forward_transform, RealSHT):
-            self.filter = SpectralAttentionS2(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              operator_type = operator_type,
-                                              sparsity_threshold = sparsity_threshold,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-        elif filter_type == "non-linear" and isinstance(forward_transform, RealFFT2):
-            self.filter = SpectralAttention2d(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              sparsity_threshold = sparsity_threshold,
-                                              use_complex_kernels = use_complex_kernels,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-        elif filter_type == "linear" and factorization is None:
+        if factorization is None:
             self.filter = SpectralConvS2(forward_transform,
                                          inverse_transform,
-                                         embed_dim,
-                                         embed_dim,
+                                         input_dim,
+                                         output_dim,
+                                         gain = gain,
                                          operator_type = operator_type,
-                                         lr_scale_exponent = lr_scale_exponent,
-                                         bias = True)
+                                         bias = bias)

-        elif filter_type == "linear" and factorization is not None:
+        elif factorization is not None:
             self.filter = FactorizedSpectralConvS2(forward_transform,
                                                    inverse_transform,
-                                                   embed_dim,
-                                                   embed_dim,
+                                                   input_dim,
+                                                   output_dim,
+                                                   gain = gain,
                                                    operator_type = operator_type,
                                                    rank = rank,
                                                    factorization = factorization,
                                                    separable = separable,
-                                                   bias = True)
+                                                   bias = bias)
         else:
             raise(NotImplementedError)
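Note: the "driscoll-healy" operator applies one complex in-to-out matrix per spherical-harmonic degree l, shared across orders m. A hedged sketch of the shape logic (contract_dhconv itself lives in .contractions; this only mirrors the new [out, in, l] weight layout):

import torch

def dhconv(x, w):
    # x: (batch, in_channels, lmax, mmax) complex SHT coefficients
    # w: (out_channels, in_channels, lmax) complex, one matrix per degree l
    return torch.einsum("bilm,oil->bolm", x, w)

x = torch.randn(2, 8, 16, 9, dtype=torch.complex64)
w = torch.randn(4, 8, 16, dtype=torch.complex64)
print(dhconv(x, w).shape)   # torch.Size([2, 4, 16, 9])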
@@ -120,60 +94,54 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
             self,
             forward_transform,
             inverse_transform,
-            embed_dim,
-            filter_type = "non-linear",
+            input_dim,
+            output_dim,
             operator_type = "driscoll-healy",
             mlp_ratio = 2.,
             drop_rate = 0.,
             drop_path = 0.,
-            act_layer = nn.GELU,
+            act_layer = nn.ReLU,
             norm_layer = nn.Identity,
-            sparsity_threshold = 0.0,
-            use_complex_kernels = True,
-            lr_scale_exponent = 0,
             factorization = None,
             separable = False,
             rank = 128,
             inner_skip = "linear",
             outer_skip = None,
-            concat_skip = False,
-            use_mlp = True,
-            complex_activation = "real",
-            spectral_layers = 3):
+            use_mlp = True):
         super(SphericalFourierNeuralOperatorBlock, self).__init__()

+        if act_layer == nn.Identity:
+            gain_factor = 1.0
+        else:
+            gain_factor = 2.0
+
+        if inner_skip == "linear" or inner_skip == "identity":
+            gain_factor /= 2.0
+
         # convolution layer
         self.filter = SpectralFilterLayer(forward_transform,
                                           inverse_transform,
-                                          embed_dim,
-                                          filter_type,
+                                          input_dim,
+                                          output_dim,
+                                          gain = gain_factor,
                                           operator_type = operator_type,
-                                          sparsity_threshold = sparsity_threshold,
-                                          use_complex_kernels = use_complex_kernels,
                                           hidden_size_factor = mlp_ratio,
-                                          lr_scale_exponent = lr_scale_exponent,
                                           factorization = factorization,
                                           separable = separable,
                                           rank = rank,
-                                          complex_activation = complex_activation,
-                                          spectral_layers = spectral_layers,
-                                          drop_rate = drop_rate)
+                                          bias = True)

         if inner_skip == "linear":
-            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
+            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
+            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim))
         elif inner_skip == "identity":
+            assert input_dim == output_dim
             self.inner_skip = nn.Identity()
         elif inner_skip == "none":
             pass
         else:
             raise ValueError(f"Unknown skip connection type {inner_skip}")

-        self.concat_skip = concat_skip
-
-        if concat_skip and inner_skip is not None:
-            self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
-        if filter_type == "linear":
-            self.act_layer = act_layer()
+        self.act_layer = act_layer()

         # first normalisation layer
@@ -182,59 +150,67 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         # dropout
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

+        gain_factor = 1.0
+        if outer_skip == "linear" or inner_skip == "identity":
+            gain_factor /= 2.
+
         if use_mlp == True:
-            mlp_hidden_dim = int(embed_dim * mlp_ratio)
-            self.mlp = MLP(in_features = embed_dim,
+            mlp_hidden_dim = int(output_dim * mlp_ratio)
+            self.mlp = MLP(in_features = output_dim,
+                           out_features = input_dim,
                            hidden_features = mlp_hidden_dim,
                            act_layer = act_layer,
                            drop_rate = drop_rate,
-                           checkpointing = False)
+                           checkpointing = False,
+                           gain = gain_factor)

         if outer_skip == "linear":
-            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
+            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
+            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim))
         elif outer_skip == "identity":
+            assert input_dim == output_dim
             self.outer_skip = nn.Identity()
         elif outer_skip == "none":
             pass
         else:
             raise ValueError(f"Unknown skip connection type {outer_skip}")

-        if concat_skip and outer_skip is not None:
-            self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
         # second normalisation layer
         self.norm1 = norm_layer()

+        # def init_weights(self, scale):
+        #     if hasattr(self, "inner_skip") and isinstance(self.inner_skip, nn.Conv2d):
+        #         gain_factor = 1.
+        #         scale = (gain_factor / embed_dim)**0.5
+        #         nn.init.normal_(self.inner_skip.weight, mean=0., std=scale)
+        #         self.filter.filter.init_weights(scale)
+        #     else:
+        #         gain_factor = 2.
+        #         scale = (gain_factor / embed_dim)**0.5
+        #         self.filter.filter.init_weights(scale)
+
     def forward(self, x):

         x, residual = self.filter(x)

-        x = self.norm0(x)
-
         if hasattr(self, "inner_skip"):
-            if self.concat_skip:
-                x = torch.cat((x, self.inner_skip(residual)), dim=1)
-                x = self.inner_skip_conv(x)
-            else:
-                x = x + self.inner_skip(residual)
+            x = x + self.inner_skip(residual)

         if hasattr(self, "act_layer"):
             x = self.act_layer(x)

+        x = self.norm0(x)
+
         if hasattr(self, "mlp"):
             x = self.mlp(x)

-        x = self.norm1(x)
-
         x = self.drop_path(x)

         if hasattr(self, "outer_skip"):
-            if self.concat_skip:
-                x = torch.cat((x, self.outer_skip(residual)), dim=1)
-                x = self.outer_skip_conv(x)
-            else:
-                x = x + self.outer_skip(residual)
+            x = x + self.outer_skip(residual)

+        x = self.norm1(x)
+
         return x
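Note: the gain_factor bookkeeping halves the initialization variance whenever a branch is summed with a skip connection, since the variance of a sum of two independent, equally scaled branches is the sum of their variances. A quick numerical illustration of that rule:

import torch

a = torch.randn(100000)
b = torch.randn(100000)

print((a + b).var().item())                        # ~2: summing doubles the variance
print((a * 0.5**0.5 + b * 0.5**0.5).var().item())  # ~1: scaling both branches by sqrt(1/2) restores it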
 class SphericalFourierNeuralOperatorNet(nn.Module):
@@ -244,8 +220,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     Parameters
     ----------
-    filter_type : str, optional
-        Type of filter to use ('linear', 'non-linear'), by default "linear"
     spectral_transform : str, optional
         Type of spectral transformation to use, by default "sht"
     operator_type : str, optional
@@ -274,30 +248,20 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         Dropout rate, by default 0.0
     drop_path_rate : float, optional
         Dropout path rate, by default 0.0
-    sparsity_threshold : float, optional
-        Threshold for sparsity, by default 0.0
     normalization_layer : str, optional
         Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
     hard_thresholding_fraction : float, optional
         Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
-    use_complex_kernels : bool, optional
-        Whether to use complex kernels, by default True
     big_skip : bool, optional
         Whether to add a single large skip connection, by default True
     rank : float, optional
         Rank of the approximation, by default 1.0
-    lr_scale_exponent : float, optional
-        exponential rescaling of spectral coefficients, by default 0.0 (no rescaling)
     factorization : Any, optional
         Type of factorization to use, by default None
     separable : bool, optional
         Whether to use separable convolutions, by default False
     rank : (int, Tuple[int]), optional
         If a factorization is used, which rank to use. Argument is passed to tensorly
-    complex_activation : str, optional
-        Type of complex activation function to use, by default "real"
-    spectral_layers : int, optional
-        Number of spectral layers, by default 3
     pos_embed : bool, optional
         Whether to use positional embedding, by default True
@@ -317,63 +281,58 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     def __init__(
             self,
-            filter_type = "linear",
             spectral_transform = "sht",
             operator_type = "driscoll-healy",
             img_size = (128, 256),
+            grid = "equiangular",
             scale_factor = 3,
             in_chans = 3,
             out_chans = 3,
             embed_dim = 256,
             num_layers = 4,
-            activation_function = "gelu",
+            activation_function = "relu",
             encoder_layers = 1,
             use_mlp = True,
             mlp_ratio = 2.,
             drop_rate = 0.,
             drop_path_rate = 0.,
-            sparsity_threshold = 0.0,
             normalization_layer = "none",
             hard_thresholding_fraction = 1.0,
             use_complex_kernels = True,
-            big_skip = True,
-            lr_scale_exponent = 0,
+            big_skip = False,
             factorization = None,
             separable = False,
             rank = 128,
-            complex_activation = "real",
-            spectral_layers = 2,
-            pos_embed = True):
+            pos_embed = False):

         super(SphericalFourierNeuralOperatorNet, self).__init__()

-        self.filter_type = filter_type
         self.spectral_transform = spectral_transform
         self.operator_type = operator_type
         self.img_size = img_size
+        self.grid = grid
         self.scale_factor = scale_factor
         self.in_chans = in_chans
         self.out_chans = out_chans
-        self.embed_dim = self.num_features = embed_dim
+        self.embed_dim = embed_dim
+        self.pos_embed_dim = self.embed_dim
         self.num_layers = num_layers
         self.hard_thresholding_fraction = hard_thresholding_fraction
         self.normalization_layer = normalization_layer
         self.use_mlp = use_mlp
         self.encoder_layers = encoder_layers
         self.big_skip = big_skip
-        self.lr_scale_exponent = lr_scale_exponent
         self.factorization = factorization
         self.separable = separable
         self.rank = rank
-        self.complex_activation = complex_activation
-        self.spectral_layers = spectral_layers

         # activation function
         if activation_function == "relu":
             self.activation_function = nn.ReLU
         elif activation_function == "gelu":
             self.activation_function = nn.GELU
+        # for debugging purposes
+        elif activation_function == "identity":
+            self.activation_function = nn.Identity
         else:
             raise ValueError(f"Unknown activation function {activation_function}")
...
@@ -391,37 +350,68 @@ class SphericalFourierNeuralOperatorNet(nn.Module):

             norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
         elif self.normalization_layer == "instance_norm":
             norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
-            norm_layer1 = norm_layer0
+            norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
         elif self.normalization_layer == "none":
             norm_layer0 = nn.Identity
             norm_layer1 = norm_layer0
         else:
             raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

-        if pos_embed:
+        if pos_embed == "latlon" or pos_embed == True:
             self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
+            nn.init.constant_(self.pos_embed, 0.0)
+        elif pos_embed == "lat":
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1))
+            nn.init.constant_(self.pos_embed, 0.0)
+        elif pos_embed == "const":
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
+            nn.init.constant_(self.pos_embed, 0.0)
         else:
             self.pos_embed = None
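The three embedding variants differ only in which spatial axes the learned bias is broadcast over. A small self-contained sketch of the shape logic (tensor sizes are illustrative, not taken from the diff):

```python
import torch

B, C, H, W = 2, 256, 128, 256
latlon = torch.zeros(1, C, H, W)  # separate bias at every grid point
lat    = torch.zeros(1, C, H, 1)  # shared along longitude (zonally symmetric)
const  = torch.zeros(1, C, 1, 1)  # a single bias per channel

x = torch.randn(B, C, H, W)
# all three variants broadcast cleanly against the feature map
for emb in (latlon, lat, const):
    assert (x + emb).shape == (B, C, H, W)
```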
-        # encoder
+        # # encoder
+        # encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+        # encoder = MLP(in_features = self.in_chans,
+        #               out_features = self.embed_dim,
+        #               hidden_features = encoder_hidden_dim,
+        #               act_layer = self.activation_function,
+        #               drop_rate = drop_rate,
+        #               checkpointing = False)
+        # self.encoder = encoder

+        # construct an encoder with num_encoder_layers
+        num_encoder_layers = 1
         encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        encoder = MLP(in_features = self.in_chans,
-                      out_features = self.embed_dim,
-                      hidden_features = encoder_hidden_dim,
-                      act_layer = self.activation_function,
-                      drop_rate = drop_rate,
-                      checkpointing = False)
-        self.encoder = encoder
-        # self.encoder = nn.Sequential(encoder, norm_layer0())
+        current_dim = self.in_chans
+        encoder_layers = []
+        for l in range(num_encoder_layers-1):
+            fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)
+            # initialize the weights correctly
+            scale = math.sqrt(2. / current_dim)
+            nn.init.normal_(fc.weight, mean=0., std=scale)
+            if fc.bias is not None:
+                nn.init.constant_(fc.bias, 0.0)
+            encoder_layers.append(fc)
+            encoder_layers.append(self.activation_function())
+            current_dim = encoder_hidden_dim
+        fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
+        scale = math.sqrt(1. / current_dim)
+        nn.init.normal_(fc.weight, mean=0., std=scale)
+        if fc.bias is not None:
+            nn.init.constant_(fc.bias, 0.0)
+        encoder_layers.append(fc)
+        self.encoder = nn.Sequential(*encoder_layers)
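The new encoder replaces the MLP with a stack of 1x1 convolutions. A 1x1 convolution is just a pointwise linear layer applied independently at every grid point, which the following sketch (hypothetical channel counts) verifies:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 1, bias=False)          # 1x1 convolution
lin  = nn.Linear(3, 8, bias=False)             # equivalent per-pixel linear map
lin.weight.data = conv.weight.data.view(8, 3)  # share the same weights

x = torch.randn(1, 3, 4, 4)
out_conv = conv(x)
out_lin  = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert torch.allclose(out_conv, out_lin, atol=1e-6)
```

On the initialization: the sqrt(2/fan_in) standard deviation matches He (Kaiming) initialization for ReLU-family activations, while the final projection uses sqrt(1/fan_in) since no nonlinearity follows it.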
         # prepare the spectral transform
         if self.spectral_transform == "sht":
             modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
+            modes_lon = int(self.w//2 * self.hard_thresholding_fraction)
+            modes_lat = modes_lon = min(modes_lat, modes_lon)

-            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
-            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
+            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
+            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
             self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
             self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
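Plugging in numbers makes the thresholding change concrete. With an internal grid of, say, h=128, w=256 (illustrative values; h and w are computed outside this hunk), the retained modes work out as:

```python
h, w = 128, 256
hard_thresholding_fraction = 1.0

modes_lat = int(h * hard_thresholding_fraction)       # 128
modes_lon = int(w // 2 * hard_thresholding_fraction)  # 128 (previously (w//2 + 1) -> 129)
modes_lat = modes_lon = min(modes_lat, modes_lon)     # clamp so that lmax == mmax
```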
...
@@ -447,8 +437,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):

             forward_transform = self.trans_down if first_layer else self.trans
             inverse_transform = self.itrans_up if last_layer else self.itrans

-            inner_skip = 'linear'
-            outer_skip = 'identity'
+            inner_skip = "none"
+            outer_skip = "identity"

             if first_layer:
                 norm_layer = norm_layer1
...
@@ -460,45 +450,53 @@ class SphericalFourierNeuralOperatorNet(nn.Module):

             block = SphericalFourierNeuralOperatorBlock(forward_transform,
                                                         inverse_transform,
                                                         self.embed_dim,
-                                                        filter_type = filter_type,
+                                                        self.embed_dim,
                                                         operator_type = self.operator_type,
                                                         mlp_ratio = mlp_ratio,
                                                         drop_rate = drop_rate,
                                                         drop_path = dpr[i],
                                                         act_layer = self.activation_function,
                                                         norm_layer = norm_layer,
-                                                        sparsity_threshold = sparsity_threshold,
-                                                        use_complex_kernels = use_complex_kernels,
                                                         inner_skip = inner_skip,
                                                         outer_skip = outer_skip,
                                                         use_mlp = use_mlp,
-                                                        lr_scale_exponent = self.lr_scale_exponent,
                                                         factorization = self.factorization,
                                                         separable = self.separable,
-                                                        rank = self.rank,
-                                                        complex_activation = self.complex_activation,
-                                                        spectral_layers = self.spectral_layers)
+                                                        rank = self.rank)

             self.blocks.append(block)
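The skip settings above are residual-connection conventions; the naming suggests that outer_skip="identity" adds the block's unmodified input to its output, while inner_skip="none" leaves the spectral filter path alone. A toy stand-in for the outer identity skip (an ordinary nn.Conv2d, not the actual spectral block, whose internals live in SphericalFourierNeuralOperatorBlock outside this diff):

```python
import torch
import torch.nn as nn

block = nn.Conv2d(8, 8, 1)        # hypothetical stand-in for one SFNO block
x = torch.randn(1, 8, 16, 32)
y = block(x) + x                  # outer identity skip (plain residual)
assert y.shape == x.shape         # the residual path requires matching shapes
```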
-        # decoder
-        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
-                           out_features = self.out_chans,
-                           hidden_features = encoder_hidden_dim,
-                           act_layer = self.activation_function,
-                           drop_rate = drop_rate,
-                           checkpointing = False)
-
-        # trunc_normal_(self.pos_embed, std=.02)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            #nn.init.normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
+        # # decoder
+        # decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+        # self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
+        #                    out_features = self.out_chans,
+        #                    hidden_features = decoder_hidden_dim,
+        #                    act_layer = self.activation_function,
+        #                    drop_rate = drop_rate,
+        #                    checkpointing = False)

+        # construct a decoder with num_decoder_layers
+        num_decoder_layers = 1
+        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+        current_dim = self.embed_dim + self.big_skip*self.in_chans
+        decoder_layers = []
+        for l in range(num_decoder_layers-1):
+            fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
+            # initialize the weights correctly
+            scale = math.sqrt(2. / current_dim)
+            nn.init.normal_(fc.weight, mean=0., std=scale)
+            if fc.bias is not None:
+                nn.init.constant_(fc.bias, 0.0)
+            decoder_layers.append(fc)
+            decoder_layers.append(self.activation_function())
+            current_dim = decoder_hidden_dim
+        fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
+        scale = math.sqrt(1. / current_dim)
+        nn.init.normal_(fc.weight, mean=0., std=scale)
+        if fc.bias is not None:
+            nn.init.constant_(fc.bias, 0.0)
+        decoder_layers.append(fc)
+        self.decoder = nn.Sequential(*decoder_layers)
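The decoder input width follows the current_dim expression above: when big_skip is enabled, the raw network input is concatenated onto the processed features before decoding. A shape-level sketch (the concatenation itself happens in the forward pass, which this hunk does not show):

```python
import torch

embed_dim, in_chans, big_skip = 256, 3, True
feat = torch.randn(1, embed_dim, 128, 256)  # features after the SFNO blocks
inp  = torch.randn(1, in_chans, 128, 256)   # raw network input

dec_in = torch.cat((feat, inp), dim=1) if big_skip else feat
assert dec_in.shape[1] == embed_dim + big_skip * in_chans
```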
     @torch.jit.ignore
     def no_weight_decay(self):
...
@@ -239,7 +239,7 @@ class ShallowWaterSolver(nn.Module):

         ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64

         # mach number relative to wave speed
-        llimit = mlimit = 20
+        llimit = mlimit = 80

         # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
         # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
...
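The surrounding context suggests that llimit and mlimit band-limit the randomly generated initial conditions, so raising them from 20 to 80 admits higher spherical harmonic degrees and orders, i.e. rougher, finer-scale initial states. A hedged sketch of the band-limiting idea (the solver's actual sampling code is not shown in this hunk; sizes are illustrative):

```python
import torch

lmax, mmax = 128, 128   # illustrative transform sizes
llimit = mlimit = 80    # band limit for the random initial state

# populate only coefficients with degree l < llimit and order m < mlimit
coeffs = torch.zeros(lmax, mmax, dtype=torch.complex64)
coeffs[:llimit, :mlimit] = torch.randn(llimit, mlimit, dtype=torch.complex64)
```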