Commit e7ceb9c8 authored by Boris Bonev
Browse files

Adding SFNO examples

parent 24490256
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.nn as nn
# complex activation functions
class ComplexCardioid(nn.Module):
    """
    Complex Cardioid activation function.

    Computes f(z) = 0.5 * (1 + cos(arg(z))) * z, scaling every complex input
    by a phase-dependent factor in [0, 1]: unchanged on the positive real
    axis, zeroed on the negative real axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        phase_gate = 0.5 * (1.0 + torch.cos(torch.angle(z)))
        return phase_gate * z
class ComplexReLU(nn.Module):
    """
    Complex-valued variants of the ReLU activation function.

    Supported modes:
      - "cartesian": apply LeakyReLU independently to real and imaginary parts.
      - "modulus":   modReLU -- rescale z by (|z| + bias) / |z| where positive
                     (learnable bias).
      - "cardioid":  phase-gated scaling, 0.5 * (1 + cos(arg(z))) * z.
      - "halfplane": pass z through when its bias-shifted phase lies in
                     [0, pi/2), otherwise scale by negative_slope.
      - "real":      apply LeakyReLU to the real part only.

    Parameters
    ----------
    negative_slope : float
        Slope of the underlying LeakyReLU (also used by "halfplane").
    mode : str
        One of "cartesian", "modulus", "cardioid", "halfplane", "real".
    bias_shape : tuple or None
        Shape of the learnable bias used by "modulus"/"halfplane";
        a scalar bias is used when None.
    scale : float
        Initial value of the learnable bias.
    """

    def __init__(self, negative_slope=0., mode="real", bias_shape=None, scale=1.):
        super().__init__()

        # store parameters
        self.mode = mode
        if self.mode in ["modulus", "halfplane"]:
            if bias_shape is not None:
                self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
            else:
                self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
        else:
            self.bias = 0

        self.negative_slope = negative_slope
        self.act = nn.LeakyReLU(negative_slope=negative_slope)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        if self.mode == "cartesian":
            zr = torch.view_as_real(z)
            za = self.act(zr)
            out = torch.view_as_complex(za)
        elif self.mode == "modulus":
            zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
            # guard against 0/0 at z == 0: torch.where evaluates both branches,
            # so dividing by a raw zabs would inject NaNs into the output (and
            # into the gradient) wherever z vanishes
            safe_zabs = torch.where(zabs == 0.0, torch.ones_like(zabs), zabs)
            out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / safe_zabs, torch.zeros_like(z))
        elif self.mode == "cardioid":
            out = 0.5 * (1. + torch.cos(z.angle())) * z
        elif self.mode == "halfplane":
            # this mode previously allocated a bias in __init__ but raised
            # NotImplementedError here; the bias is an angle offset
            modified_angle = torch.angle(z) - self.bias
            condition = torch.logical_and((0. <= modified_angle), (modified_angle < torch.pi / 2.))
            out = torch.where(condition, z, self.negative_slope * z)
        elif self.mode == "real":
            zr = torch.view_as_real(z)
            outr = zr.clone()
            outr[..., 0] = self.act(zr[..., 0])
            out = torch.view_as_complex(outr)
        else:
            raise NotImplementedError

        return out
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d contraction in real-valued layout.

    `a` and `b` carry real/imag parts in their trailing dims (`s`, `r`);
    the complex product is reassembled into a trailing real/imag dim.
    """
    prods = torch.einsum("bixys,kixr->srbkx", a, b)
    re = prods[0, 0, ...] - prods[1, 1, ...]
    im = prods[1, 0, ...] + prods[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same contraction as compl_contract2d_fwd, via native complex dtypes."""
    a_cplx = torch.view_as_complex(a)
    b_cplx = torch.view_as_complex(b)
    contracted = torch.einsum("bixy,kix->bkx", a_cplx, b_cplx)
    return torch.view_as_real(contracted)
@torch.jit.script
def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d contraction in real-valued layout (trailing real/imag dim)."""
    prods = torch.einsum("bins,kinr->srbkn", a, b)
    re = prods[0, 0, ...] - prods[1, 1, ...]
    im = prods[1, 0, ...] + prods[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same contraction as compl_contract_fwd, via native complex dtypes."""
    contracted = torch.einsum("bin,kin->bkn", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(contracted)
# Helper routines for spherical MLPs
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing (1d) in real-valued layout."""
    prods = torch.einsum("bixs,ior->srbox", a, b)
    re = prods[0, 0, ...] - prods[1, 1, ...]
    im = prods[1, 0, ...] + prods[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul1d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing (1d) using native complex dtypes."""
    mixed = torch.einsum("bix,io->box", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(mixed)
@torch.jit.script
def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add (1d): compl_mul1d_fwd(a, b) plus bias c, real-valued layout."""
    return compl_mul1d_fwd(a, b) + c
@torch.jit.script
def compl_muladd1d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add (1d) computed with native complex views."""
    prod = torch.view_as_complex(compl_mul1d_fwd_c(a, b))
    bias = torch.view_as_complex(c)
    return torch.view_as_real(prod + bias)
# Helper routines for FFT MLPs
@torch.jit.script
def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing (2d) in real-valued layout."""
    prods = torch.einsum("bixys,ior->srboxy", a, b)
    re = prods[0, 0, ...] - prods[1, 1, ...]
    im = prods[1, 0, ...] + prods[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing (2d) using native complex dtypes."""
    mixed = torch.einsum("bixy,io->boxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(mixed)
@torch.jit.script
def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused multiply-add (2d): compl_mul2d_fwd(a, b) plus bias c, real-valued layout."""
    return compl_mul2d_fwd(a, b) + c
@torch.jit.script
def compl_muladd2d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add (2d) computed with native complex views."""
    prod = torch.view_as_complex(compl_mul2d_fwd_c(a, b))
    bias = torch.view_as_complex(c)
    return torch.view_as_real(prod + bias)
@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Real-valued channel mixing (2d): contract in-channels of a against b."""
    return torch.einsum("bixy,io->boxy", a, b)
@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused real multiply-add (2d): mix channels of a with b, then add bias c.

    Fix: the original called compl_mul2d_fwd_c here, which views its real
    inputs as interleaved complex data (requiring a trailing dim of size 2);
    for real tensors the correct operation is the plain channel contraction,
    matching real_mul2d_fwd.
    """
    return torch.einsum("bixy,io->boxy", a, b) + c
# for all the experimental layers
# @torch.jit.script
# def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
# resc = torch.einsum("bixy,xio->boxy", ac, bc)
# res = torch.view_as_real(resc)
# return res
# @torch.jit.script
# def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
# cc = torch.view_as_complex(c)
# return torch.view_as_real(tmpcc + cc)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.core import FactorizedTensor
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
    """Contract input x with a dense (full) weight tensor via einsum.

    x is laid out as (batch, in_channels, x, y, ...). Unless `separable`,
    the weight carries an extra output-channel dimension. `operator_type`
    controls how the trailing (mode) dimension is treated.
    """
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        # fixed misspelled error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)

    # factorized weights expose to_tensor() to reconstruct the dense tensor
    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    return tl.einsum(eq, x, weight)
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with the factors of a CP-factorized weight,
    without reconstructing the dense tensor.
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order + 1]
    out_syms = list(x_syms)

    if separable:
        factor_syms = [einsum_symbols[1] + rank_sym]  # in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym]  # in, out

    factor_syms += [xs + rank_sym for xs in x_syms[2:]]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        out_syms[-1] = einsum_symbols[order + 2]
        factor_syms += [out_syms[-1] + rank_sym]
    elif operator_type == 'vector':
        factor_syms.pop()
    else:
        # fixed misspelled error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with the core and factors of a
    Tucker-factorized weight, without reconstructing the dense tensor.
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)

    if separable:
        core_syms = einsum_symbols[order + 1:2 * order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
        factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)]  # x, y, ...
    else:
        core_syms = einsum_symbols[order + 1:2 * order + 1]
        out_syms[1] = out_sym
        # note: the original comment said "out, in" but the order is in, out
        factor_syms = [einsum_symbols[1] + core_syms[0], out_sym + core_syms[1]]  # in, out
        factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    else:
        # fixed misspelled error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
    """Contract input x directly with the cores of a tensor-train (TT)
    factorized weight, without reconstructing the dense tensor.
    """
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size

    if not separable:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        # fixed misspelled error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    # each TT core carries (left rank, mode, right rank) indices
    rank_syms = list(einsum_symbols[order + 2:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i + 1]])

    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation='reconstructed', separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
    """
    if implementation == 'reconstructed':
        return _contract_dense

    if implementation != 'factorized':
        raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')

    # plain tensors use the dense contraction even in 'factorized' mode
    if torch.is_tensor(weight):
        return _contract_dense

    if not isinstance(weight, FactorizedTensor):
        raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')

    # dispatch on the factorization name
    handlers = {
        'complexdense': _contract_dense,
        'complextucker': _contract_tucker,
        'complextt': _contract_tt,
        'complexcp': _contract_cp,
    }
    contract_fn = handlers.get(weight.name.lower())
    if contract_fn is None:
        raise ValueError(f'Got unexpected factorized weight type {weight.name}')
    return contract_fn
This diff is collapsed.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import time
from tqdm import tqdm
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda import amp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# wandb logging
import wandb
wandb.login()
import sys
sys.path.append(os.path.join(os.path.dirname( __file__), "../"))
from pde_sphere import SphereSolver
def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
    """(Squared) L2 loss on the sphere, using the solver's grid quadrature.

    The pointwise squared error is integrated over the grid, summed over
    channels, optionally normalized by the target norm, and averaged over
    the batch.
    """
    err2 = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
    if relative:
        err2 = err2 / solver.integrate_grid(tar ** 2, dimensionless=True).sum(dim=-1)

    if squared:
        loss = err2
    else:
        loss = torch.sqrt(err2)

    return loss.mean()
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=False):
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=False):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
coeffs = spectral_weights * coeffs
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_coeffs = spectral_weights * tar_coeffs
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def h1loss_sphere(solver, prd, tar, relative=False, squared=False):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
h1_coeffs = spectral_weights * coeffs
h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
l2_loss = torch.sum(l2_norm2, dim=(-1,-2))
# strictly speaking this is not exactly h1 loss
if not squared:
loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
else:
loss = h1_loss + l2_loss
if relative:
raise NotImplementedError("Relative H1 loss not implemented")
loss = loss.mean()
return loss
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
    """L2 loss weighted per channel by the target's fluctuation (tar - inp),
    so channels that change more between input and target contribute more.
    """
    # compute the weighting factor first
    fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
    weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)

    loss = weight * solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
    if relative:
        loss = loss / (weight * solver.integrate_grid(tar ** 2, dimensionless=True, polar_opt=polar_opt))

    return torch.mean(loss)
def main(train=True, load_checkpoint=False, enable_amp=False):
    """Train SFNO/FNO variants on the spherical shallow water equations and
    benchmark them against the classical spectral solver.

    Parameters
    ----------
    train : bool
        Whether to run the single-step and two-step training loops.
    load_checkpoint : bool
        Whether to restore model weights from ./checkpoints before training.
    enable_amp : bool
        Enables automatic mixed precision for forward/backward passes.
    """

    # set seed
    torch.manual_seed(333)
    torch.cuda.manual_seed(333)

    # set device
    # NOTE(review): 'cuda:1' is hard-coded and assumes a second GPU — confirm
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(device.index)

    # dataset
    from utils.pde_dataset import PdeDataset

    # 1 hour prediction steps
    dt = 1*3600
    dt_solver = 150
    nsteps = dt//dt_solver
    dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
    # There is still an issue with parallel dataloading. Do NOT use it at the moment
    # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
    solver = dataset.solver.to(device)

    nlat = dataset.nlat
    nlon = dataset.nlon

    # training function
    def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn='l2'):
        """Run the training loop for `nepochs` epochs and return the final
        relative validation loss. `nfuture` extra autoregressive steps are
        unrolled during training for multistep finetuning."""
        train_start = time.time()

        for epoch in range(nepochs):
            # time each epoch
            epoch_start = time.time()

            dataloader.dataset.set_initial_condition('random')
            dataloader.dataset.set_num_examples(num_examples)

            # do the training
            acc_loss = 0
            model.train()
            for inp, tar in dataloader:
                with amp.autocast(enabled=enable_amp):
                    prd = model(inp)
                    for _ in range(nfuture):
                        prd = model(prd)
                    if loss_fn == 'l2':
                        loss = l2loss_sphere(solver, prd, tar, relative=False)
                    elif loss_fn == 'h1':
                        loss = h1loss_sphere(solver, prd, tar, relative=False)
                    elif loss_fn == 'spectral':
                        loss = spectral_loss_sphere(solver, prd, tar, relative=False)
                    elif loss_fn == 'fluct':
                        loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
                    else:
                        raise NotImplementedError(f'Unknown loss function {loss_fn}')

                acc_loss += loss.item() * inp.size(0)

                optimizer.zero_grad(set_to_none=True)
                gscaler.scale(loss).backward()
                gscaler.step(optimizer)
                gscaler.update()

            acc_loss = acc_loss / len(dataloader.dataset)

            dataloader.dataset.set_initial_condition('random')
            dataloader.dataset.set_num_examples(num_valid)

            # perform validation
            valid_loss = 0
            model.eval()
            with torch.no_grad():
                for inp, tar in dataloader:
                    prd = model(inp)
                    for _ in range(nfuture):
                        prd = model(prd)
                    loss = l2loss_sphere(solver, prd, tar, relative=True)

                    valid_loss += loss.item() * inp.size(0)

            valid_loss = valid_loss / len(dataloader.dataset)

            if scheduler is not None:
                scheduler.step(valid_loss)

            epoch_time = time.time() - epoch_start

            print(f'--------------------------------------------------------------------------------')
            print(f'Epoch {epoch} summary:')
            print(f'time taken: {epoch_time}')
            print(f'accumulated training loss: {acc_loss}')
            print(f'relative validation loss: {valid_loss}')

            if wandb.run is not None:
                current_lr = optimizer.param_groups[0]['lr']
                wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})

        train_time = time.time() - train_start

        print(f'--------------------------------------------------------------------------------')
        print(f'done. Training took {train_time}.')

        return valid_loss

    # rolls out the FNO and compares to the classical solver
    def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=20):
        """Roll the model out autoregressively for `nics` random initial
        conditions and compare it (relative loss and wallclock time) with the
        classical solver. Plots are saved for the last initial condition."""
        model.eval()

        losses = np.zeros(nics)
        fno_times = np.zeros(nics)
        nwp_times = np.zeros(nics)

        for iic in range(nics):
            ic = dataset.solver.random_initial_condition(mach=0.2)
            inp_mean = dataset.inp_mean
            inp_var = dataset.inp_var

            prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
            prd = prd.unsqueeze(0)
            uspec = ic.clone()

            # ML model
            start_time = time.time()
            for i in range(1, autoreg_steps+1):
                # evaluate the ML model
                prd = model(prd)

                if iic == nics-1 and nskip > 0 and i % nskip == 0:
                    # do plotting
                    fig = plt.figure(figsize=(7.5, 6))
                    dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
                    plt.savefig(path_root+'_pred_'+str(i//nskip)+'.png')
                    plt.clf()

            fno_times[iic] = time.time() - start_time

            # classical model
            start_time = time.time()
            for i in range(1, autoreg_steps+1):
                # advance classical model
                uspec = dataset.solver.timestep(uspec, nsteps)

                # check nskip > 0 before computing i % nskip — the previous
                # ordering raised ZeroDivisionError when plotting was disabled
                if iic == nics-1 and nskip > 0 and i % nskip == 0:
                    ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
                    fig = plt.figure(figsize=(7.5, 6))
                    dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
                    plt.savefig(path_root+'_truth_'+str(i//nskip)+'.png')
                    plt.clf()

            nwp_times[iic] = time.time() - start_time

            # compare the unnormalized fields
            ref = dataset.solver.spec2grid(uspec)
            prd = prd * torch.sqrt(inp_var) + inp_mean
            losses[iic] = l2loss_sphere(solver, prd, ref, relative=True).item()

        return losses, fno_times, nwp_times

    def count_parameters(model):
        """Number of trainable parameters in `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # prepare dicts containing models and corresponding metrics
    models = {}
    metrics = {}

    # # U-Net if installed
    # from models.unet import UNet
    # models['unet_baseline'] = partial(UNet)

    # SFNO and FNO models
    from models.sfno import SphericalFourierNeuralOperatorNet as SFNO

    # SFNO models
    models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
                                                       num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
    models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
                                                     num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real', operator_type='diagonal')

    # FNO models
    models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', filter_type='linear', img_size=(nlat, nlon),
                                                      num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')
    models['fno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='fft', filter_type='non-linear', img_size=(nlat, nlon),
                                                    num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real')

    # iterate over models and train each model
    root_path = os.path.dirname(__file__)
    for model_name, model_handle in models.items():
        model = model_handle().to(device)

        metrics[model_name] = {}

        num_params = count_parameters(model)
        print(f'number of trainable params: {num_params}')
        metrics[model_name]['num_params'] = num_params

        if load_checkpoint:
            model.load_state_dict(torch.load(os.path.join(root_path, 'checkpoints/'+model_name)))

        # run the training
        if train:
            run = wandb.init(project="sfno spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)

            # optimizer:
            optimizer = torch.optim.Adam(model.parameters(), lr=1E-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            gscaler = amp.GradScaler(enabled=enable_amp)

            start_time = time.time()

            # single-step pretraining
            print(f'Training {model_name}, single step')
            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn='l2')

            # multistep finetuning with a reduced learning rate
            print(f'Training {model_name}, two step')
            optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            gscaler = amp.GradScaler(enabled=enable_amp)
            dataloader.dataset.nsteps = 2 * dt//dt_solver
            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1)
            dataloader.dataset.nsteps = 1 * dt//dt_solver

            training_time = time.time() - start_time

            run.finish()

            torch.save(model.state_dict(), os.path.join(root_path, 'checkpoints/'+model_name))

        # set seed (reproducible inference independent of the training runs)
        torch.manual_seed(333)
        torch.cuda.manual_seed(333)

        with torch.inference_mode():
            losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, 'paper_figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
            metrics[model_name]['loss_mean'] = np.mean(losses)
            metrics[model_name]['loss_std'] = np.std(losses)
            metrics[model_name]['fno_time_mean'] = np.mean(fno_times)
            metrics[model_name]['fno_time_std'] = np.std(fno_times)
            metrics[model_name]['nwp_time_mean'] = np.mean(nwp_times)
            metrics[model_name]['nwp_time_std'] = np.std(nwp_times)
            if train:
                metrics[model_name]['training_time'] = training_time

    df = pd.DataFrame(metrics)
    df.to_pickle(os.path.join(root_path, 'output_data/metrics.pkl'))
# script entry point; the 'forkserver' start method is forced here —
# presumably because the default 'fork' is problematic with the CUDA
# context used by the solver/dataloader (TODO confirm)
if __name__ == "__main__":
    import torch.multiprocessing as mp
    mp.set_start_method('forkserver', force=True)

    main(train=True, load_checkpoint=False, enable_amp=False)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import os
from math import ceil
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "torch_harmonics"))
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "examples"))
from shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
    """Custom Dataset class for PDE training data.

    Each sample is an (input, target) pair of grid-space fields where the
    target is the input advanced `nsteps` solver steps in time.

    Parameters
    ----------
    dt : float
        Time difference between input and target.
    nsteps : int
        Number of solver steps spanning dt (dt_solver = dt / nsteps).
    dims : tuple
        (nlat, nlon) grid dimensions.
    pde : str
        Only 'shallow water equations' is supported.
    initial_condition : str
        'random' or 'galewsky'.
    num_examples : int
        Number of (randomly generated) samples per epoch.
    device : torch.device
        Device on which the solver runs.
    normalize : bool
        Whether to standardize samples with channel statistics estimated
        from a single sample draw.
    stream : torch.cuda.Stream or None
        Reserved for asynchronous sample generation (currently unused).
    """

    def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random',
                 num_examples=32, device=torch.device('cpu'), normalize=True, stream=None):
        self.num_examples = num_examples
        self.device = device
        self.stream = stream

        self.nlat = dims[0]
        self.nlon = dims[1]

        # number of solver steps used to compute the target
        self.nsteps = nsteps
        self.normalize = normalize

        if pde == 'shallow water equations':
            lmax = ceil(self.nlat / 3)
            mmax = lmax
            dt_solver = dt / float(self.nsteps)
            self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()
        else:
            raise NotImplementedError

        self.set_initial_condition(ictype=initial_condition)

        if self.normalize:
            # channel-wise statistics estimated from a single sample draw
            inp0, _ = self._get_sample()
            self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
            self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)

    def __len__(self):
        # a deterministic initial condition yields exactly one sample
        return self.num_examples if self.ictype == 'random' else 1

    def set_initial_condition(self, ictype='random'):
        self.ictype = ictype

    def set_num_examples(self, num_examples=32):
        self.num_examples = num_examples

    def _get_sample(self):
        """Generate one (input, target) pair in grid space.

        Raises
        ------
        ValueError
            If the current initial-condition type is unknown (previously an
            unknown type fell through to an opaque UnboundLocalError).
        """
        if self.ictype == 'random':
            inp = self.solver.random_initial_condition(mach=0.2)
        elif self.ictype == 'galewsky':
            inp = self.solver.galewsky_initial_condition()
        else:
            raise ValueError(f"Unknown initial condition type {self.ictype}")

        # solve pde for n steps to return the target
        tar = self.solver.timestep(inp, self.nsteps)
        inp = self.solver.spec2grid(inp)
        tar = self.solver.spec2grid(tar)

        return inp, tar

    def __getitem__(self, index):
        # NOTE: asynchronous sample generation on a dedicated CUDA stream was
        # previously attempted here; due to a known issue with parallel
        # dataloading, samples are generated synchronously for now.
        # inference_mode already disables autograd tracking, so the nested
        # no_grad from the original was redundant and has been dropped.
        with torch.inference_mode():
            inp, tar = self._get_sample()
            if self.normalize:
                inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
                tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)

        return inp.clone(), tar.clone()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment