Commit e7ceb9c8 authored by Boris Bonev

Adding SFNO examples

parent 24490256
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.nn as nn
# complex activation functions
class ComplexCardioid(nn.Module):
"""
Complex Cardioid activation function
"""
def __init__(self):
super(ComplexCardioid, self).__init__()
def forward(self, z: torch.Tensor) -> torch.Tensor:
out = 0.5 * (1. + torch.cos(z.angle())) * z
return out
class ComplexReLU(nn.Module):
"""
Complex-valued variants of the ReLU activation function
"""
def __init__(self, negative_slope=0., mode="real", bias_shape=None, scale=1.):
super(ComplexReLU, self).__init__()
# store parameters
self.mode = mode
if self.mode in ["modulus", "halfplane"]:
if bias_shape is not None:
self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
else:
self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
else:
self.bias = 0
self.negative_slope = negative_slope
self.act = nn.LeakyReLU(negative_slope = negative_slope)
def forward(self, z: torch.Tensor) -> torch.Tensor:
if self.mode == "cartesian":
zr = torch.view_as_real(z)
za = self.act(zr)
out = torch.view_as_complex(za)
        elif self.mode == "modulus":
            zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
            # clamp the modulus away from zero to avoid a 0/0 NaN at the origin (where z = 0 anyway)
            out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs.clamp(min=1e-12), 0.0)
elif self.mode == "cardioid":
out = 0.5 * (1. + torch.cos(z.angle())) * z
# elif self.mode == "halfplane":
# # bias is an angle parameter in this case
# modified_angle = torch.angle(z) - self.bias
# condition = torch.logical_and( (0. <= modified_angle), (modified_angle < torch.pi/2.) )
# out = torch.where(condition, z, self.negative_slope * z)
elif self.mode == "real":
zr = torch.view_as_real(z)
outr = zr.clone()
outr[..., 0] = self.act(zr[..., 0])
out = torch.view_as_complex(outr)
else:
raise NotImplementedError
return out
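A minimal usage sketch for the activations above (illustrative, not part of the module): both act elementwise on complex tensors such as spherical-harmonic coefficients. The `models.activations` import path matches how the module is imported later in this commit.

import torch
from models.activations import ComplexCardioid, ComplexReLU

z = torch.randn(2, 8, 11, 6, dtype=torch.complex64)  # (batch, channels, l, m)
print(ComplexCardioid()(z).shape)                    # torch.Size([2, 8, 11, 6])
print(ComplexReLU(mode="real")(z).shape)             # ReLU applied to Re(z) only (default negative_slope=0)
print(ComplexReLU(mode="cardioid")(z).shape)         # same map as ComplexCardioid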
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixys,kixr->srbkx", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
@torch.jit.script
def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
res = torch.einsum("bixy,kix->bkx", ac, bc)
return torch.view_as_real(res)
@torch.jit.script
def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bins,kinr->srbkn", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
@torch.jit.script
def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
res = torch.einsum("bin,kin->bkn", ac, bc)
return torch.view_as_real(res)
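A consistency sketch (illustrative): the `_c` variants contract native complex tensors, while the plain variants use an explicit trailing real/imaginary dimension of size 2; both compute the same complex product.

import torch
from models.contractions import compl_contract_fwd, compl_contract_fwd_c

a = torch.randn(2, 4, 7, 2)  # (batch, in_channels, modes, re/im)
b = torch.randn(3, 4, 7, 2)  # (out_channels, in_channels, modes, re/im)
print(torch.allclose(compl_contract_fwd(a, b), compl_contract_fwd_c(a, b), atol=1e-5))  # True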
# Helper routines for spherical MLPs
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixs,ior->srbox", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
@torch.jit.script
def compl_mul1d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
resc = torch.einsum("bix,io->box", ac, bc)
res = torch.view_as_real(resc)
return res
@torch.jit.script
def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
res = compl_mul1d_fwd(a, b) + c
return res
@torch.jit.script
def compl_muladd1d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
tmpcc = torch.view_as_complex(compl_mul1d_fwd_c(a, b))
cc = torch.view_as_complex(c)
return torch.view_as_real(tmpcc + cc)
# Helper routines for FFT MLPs
@torch.jit.script
def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixys,ior->srboxy", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
@torch.jit.script
def compl_mul2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,io->boxy", ac, bc)
res = torch.view_as_real(resc)
return res
@torch.jit.script
def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
res = compl_mul2d_fwd(a, b) + c
return res
@torch.jit.script
def compl_muladd2d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
tmpcc = torch.view_as_complex(compl_mul2d_fwd_c(a, b))
cc = torch.view_as_complex(c)
return torch.view_as_real(tmpcc + cc)
@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
out = torch.einsum("bixy,io->boxy", a, b)
return out
@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    return real_mul2d_fwd(a, b) + c
# contractions for the experimental 'vector' attention layers; SpectralAttentionS2
# references these handles, so they are defined here rather than left commented out
@torch.jit.script
def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    ac = torch.view_as_complex(a)
    bc = torch.view_as_complex(b)
    resc = torch.einsum("bixy,xio->boxy", ac, bc)
    res = torch.view_as_real(resc)
    return res

@torch.jit.script
def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
    cc = torch.view_as_complex(c)
    return torch.view_as_real(tmpcc + cc)
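A shape sketch for the spectral MLP helpers (illustrative): the weights mix channels pointwise in frequency space, and the muladd variants add a broadcastable complex bias.

import torch
from models.contractions import compl_mul2d_fwd, compl_muladd2d_fwd_c

x = torch.randn(2, 4, 5, 6, 2)  # (batch, in_channels, h, w, re/im)
w = torch.randn(4, 8, 2)        # (in_channels, out_channels, re/im)
b = torch.randn(8, 1, 1, 2)     # complex bias, broadcast over (h, w)
print(compl_mul2d_fwd(x, w).shape)          # torch.Size([2, 8, 5, 6, 2])
print(compl_muladd2d_fwd_c(x, w, b).shape)  # torch.Size([2, 8, 5, 6, 2])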
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.core import FactorizedTensor
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
order = tl.ndim(x)
# batch-size, in_channels, x, y...
x_syms = list(einsum_symbols[:order])
# in_channels, out_channels, x, y...
weight_syms = list(x_syms[1:]) # no batch-size
# batch-size, out_channels, x, y...
if separable:
out_syms = [x_syms[0]] + list(weight_syms)
else:
weight_syms.insert(1, einsum_symbols[order]) # outputs
out_syms = list(weight_syms)
out_syms[0] = x_syms[0]
if operator_type == 'diagonal':
pass
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
elif operator_type == 'vector':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
eq= ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)
if not torch.is_tensor(weight):
weight = weight.to_tensor()
return tl.einsum(eq, x, weight)
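A worked example of the einsum string that _contract_dense builds (illustrative): for a 4D input (batch, channel, lat, lon) with operator_type='diagonal', the equation is "abcd,becd->aecd" in the non-separable case and "abcd,bcd->abcd" in the separable one.

import torch
from models.factorizations import _contract_dense

x = torch.randn(2, 4, 6, 5, dtype=torch.cfloat)
w = torch.randn(4, 8, 6, 5, dtype=torch.cfloat)  # (in, out, lat, lon)
print(_contract_dense(x, w).shape)  # torch.Size([2, 8, 6, 5])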
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
order = tl.ndim(x)
x_syms = str(einsum_symbols[:order])
rank_sym = einsum_symbols[order]
out_sym = einsum_symbols[order+1]
out_syms = list(x_syms)
if separable:
factor_syms = [einsum_symbols[1]+rank_sym] #in only
else:
out_syms[1] = out_sym
factor_syms = [einsum_symbols[1]+rank_sym, out_sym+rank_sym] #in, out
factor_syms += [xs+rank_sym for xs in x_syms[2:]] #x, y, ...
if operator_type == 'diagonal':
pass
elif operator_type == 'block-diagonal':
out_syms[-1] = einsum_symbols[order+2]
factor_syms += [out_syms[-1] + rank_sym]
elif operator_type == 'vector':
factor_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)
return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
order = tl.ndim(x)
x_syms = str(einsum_symbols[:order])
out_sym = einsum_symbols[order]
out_syms = list(x_syms)
if separable:
core_syms = einsum_symbols[order+1:2*order]
# factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
factor_syms = [xs+rs for (xs, rs) in zip(x_syms[1:], core_syms)] #x, y, ...
else:
core_syms = einsum_symbols[order+1:2*order+1]
out_syms[1] = out_sym
factor_syms = [einsum_symbols[1]+core_syms[0], out_sym+core_syms[1]] #out, in
factor_syms += [xs+rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])] #x, y, ...
if operator_type == 'diagonal':
pass
elif operator_type == 'block-diagonal':
raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
else:
raise ValueError(f"Unkonw operator type {operator_type}")
eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)
return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
order = tl.ndim(x)
x_syms = list(einsum_symbols[:order])
weight_syms = list(x_syms[1:]) # no batch-size
if not separable:
weight_syms.insert(1, einsum_symbols[order]) # outputs
out_syms = list(weight_syms)
out_syms[0] = x_syms[0]
else:
out_syms = list(x_syms)
if operator_type == 'diagonal':
pass
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
elif operator_type == 'vector':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
rank_syms = list(einsum_symbols[order+2:])
tt_syms = []
for i, s in enumerate(weight_syms):
tt_syms.append([rank_syms[i], s, rank_syms[i+1]])
eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)
return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation='reconstructed', separable=False):
"""Generic ND implementation of Fourier Spectral Conv contraction
Parameters
----------
weight : tensorly-torch's FactorizedTensor
implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
whether to reconstruct the weight and do a forward pass (reconstructed)
or contract directly the factors of the factorized weight with the input (factorized)
Returns
-------
function : (x, weight) -> x * weight in Fourier space
"""
if implementation == 'reconstructed':
return _contract_dense
elif implementation == 'factorized':
if torch.is_tensor(weight):
return _contract_dense
elif isinstance(weight, FactorizedTensor):
if weight.name.lower() == 'complexdense':
return _contract_dense
elif weight.name.lower() == 'complextucker':
return _contract_tucker
elif weight.name.lower() == 'complextt':
return _contract_tt
elif weight.name.lower() == 'complexcp':
return _contract_cp
else:
raise ValueError(f'Got unexpected factorized weight type {weight.name}')
else:
raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')
else:
raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')
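A usage sketch (illustrative) mirroring how SpectralConvS2 below builds its weight: a ComplexDense factorized tensor from tltorch plus the matching contraction function. The constructor arguments follow the call in SpectralConvS2.

import torch
from tltorch.factorized_tensors.core import FactorizedTensor
from models.factorizations import get_contract_fun, _contract_dense

weight = FactorizedTensor.new([4, 8, 6, 5], rank=0.5, factorization='ComplexDense', fixed_rank_modes=False)
weight.normal_(0, 0.02)
contract = get_contract_fun(weight, implementation='factorized')
assert contract is _contract_dense
x = torch.randn(2, 4, 6, 5, dtype=torch.cfloat)
print(contract(x, weight, separable=False, operator_type='diagonal').shape)  # torch.Size([2, 8, 6, 5])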
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from functools import partial
from collections import OrderedDict
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch.nn.modules.container import Sequential
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch.cuda import amp
from typing import Optional
import math
from torch_harmonics import *
from models.contractions import *
from models.activations import *
from models.factorizations import get_contract_fun
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# tl.set_backend('pytorch')
# use_opt_einsum('optimal')
from tltorch.factorized_tensors.core import FactorizedTensor
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1. - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
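A sanity sketch for stochastic depth (illustrative): kept samples are rescaled by 1/keep_prob, so the output matches the input in expectation.

import torch
from models.layers import drop_path

x = torch.ones(10000, 8)
out = drop_path(x, drop_prob=0.3, training=True)
print(out.unique())       # tensor([0.0000, 1.4286]): rows are either dropped or scaled by 1/0.7
print(out.mean().item())  # close to 1.0 in expectation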
class MLP(nn.Module):
def __init__(self,
in_features,
hidden_features = None,
out_features = None,
act_layer = nn.GELU,
output_bias = True,
drop_rate = 0.,
checkpointing = False):
super(MLP, self).__init__()
self.checkpointing = checkpointing
out_features = out_features or in_features
hidden_features = hidden_features or in_features
fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
# ln1 = norm_layer(num_features=hidden_features)
act = act_layer()
fc2 = nn.Conv2d(hidden_features, out_features, 1, bias = output_bias)
if drop_rate > 0.:
drop = nn.Dropout(drop_rate)
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
self.fwd = nn.Sequential(fc1, act, fc2)
@torch.jit.ignore
def checkpoint_forward(self, x):
return checkpoint(self.fwd, x)
def forward(self, x):
if self.checkpointing:
return self.checkpoint_forward(x)
else:
return self.fwd(x)
class RealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self,
nlat,
nlon,
lmax = None,
mmax = None):
super(RealFFT2, self).__init__()
self.nlat = nlat
self.nlon = nlon
self.lmax = lmax or self.nlat
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., :math.ceil(self.lmax/2), :self.mmax], y[..., -math.floor(self.lmax/2):, :self.mmax]), dim=-2)
return y
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self,
nlat,
nlon,
lmax = None,
mmax = None):
super(InverseRealFFT2, self).__init__()
self.nlat = nlat
self.nlon = nlon
self.lmax = lmax or self.nlat
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
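A round-trip sketch (illustrative): with the default (full) lmax and mmax, the truncated rFFT wrapper and its inverse reconstruct the input exactly, mirroring the SHT interface used below.

import torch
from models.layers import RealFFT2, InverseRealFFT2

nlat, nlon = 32, 64
fwd, inv = RealFFT2(nlat, nlon), InverseRealFFT2(nlat, nlon)
x = torch.randn(2, 3, nlat, nlon)
print(torch.allclose(inv(fwd(x)), x, atol=1e-5))  # True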
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
scale = 'auto',
operator_type = 'diagonal',
rank = 0.2,
factorization = None,
separable = False,
implementation = 'factorized',
decomposition_kwargs=dict(),
bias = False):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
scale = (1 / (in_channels * out_channels))
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
# Make sure we are using a Complex Factorized Tensor
if factorization is None:
factorization = 'Dense' # No factorization
if not factorization.lower().startswith('complex'):
factorization = f'Complex{factorization}'
# remember factorization details
self.operator_type = operator_type
self.rank = rank
self.factorization = factorization
self.separable = separable
        # forward and inverse transforms need to agree on the retained modes
        assert self.forward_transform.lmax == self.modes_lat
        assert self.forward_transform.mmax == self.modes_lon
weight_shape = [in_channels]
if not self.separable:
weight_shape += [out_channels]
if self.operator_type == 'diagonal':
weight_shape += [self.modes_lat, self.modes_lon]
elif self.operator_type == 'block-diagonal':
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
elif self.operator_type == 'vector':
weight_shape += [self.modes_lat]
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
fixed_rank_modes=False, **decomposition_kwargs)
# initialization of weights
self.weight.normal_(0, scale)
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias:
self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
B, C, H, W = x.shape
with amp.autocast(enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
if hasattr(self, 'bias'):
x = x + self.bias
x = x.type(dtype)
return x, residual
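A usage sketch (illustrative): a spherical spectral convolution assembled from a torch-harmonics SHT pair, the same pattern SpectralFilterLayer uses below for the linear filter type. Shapes and grid are illustrative.

import torch
from torch_harmonics import RealSHT, InverseRealSHT
from models.layers import SpectralConvS2

nlat, nlon = 32, 64
sht = RealSHT(nlat, nlon, grid='equiangular').float()
isht = InverseRealSHT(nlat, nlon, grid='equiangular').float()
conv = SpectralConvS2(sht, isht, in_channels=3, out_channels=3)
y, residual = conv(torch.randn(2, 3, nlat, nlon))
print(y.shape, residual.shape)  # torch.Size([2, 3, 32, 64]) for both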
class SpectralAttention2d(nn.Module):
"""
geometrical Spectral Attention layer
"""
def __init__(self,
forward_transform,
inverse_transform,
embed_dim,
sparsity_threshold = 0.0,
hidden_size_factor = 2,
use_complex_kernels = False,
complex_activation = 'real',
bias = False,
spectral_layers = 1,
drop_rate = 0.):
super(SpectralAttention2d, self).__init__()
self.embed_dim = embed_dim
self.sparsity_threshold = sparsity_threshold
self.hidden_size = int(hidden_size_factor * self.embed_dim)
self.scale = 1 / embed_dim**2
self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
self.spectral_layers = spectral_layers
self.modes_lat = forward_transform.lmax
self.modes_lon = forward_transform.mmax
# only storing the forward handle to be able to call it
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
assert inverse_transform.lmax == self.modes_lat
assert inverse_transform.mmax == self.modes_lon
# weights
w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)]
for l in range(1, self.spectral_layers):
w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2))
self.w = nn.ParameterList(w)
if bias:
            self.b = nn.ParameterList([self.scale * torch.randn(self.hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])
self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2))
self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()
self.activations = nn.ModuleList([])
for l in range(0, self.spectral_layers):
self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1), scale=self.scale))
def forward_mlp(self, x):
x = torch.view_as_real(x)
xr = x
for l in range(self.spectral_layers):
if hasattr(self, 'b'):
xr = self.mul_add_handle(xr, self.w[l], self.b[l])
else:
xr = self.mul_handle(xr, self.w[l])
xr = torch.view_as_complex(xr)
xr = self.activations[l](xr)
xr = self.drop(xr)
xr = torch.view_as_real(xr)
x = self.mul_handle(xr, self.wout)
x = torch.view_as_complex(x)
return x
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with amp.autocast(enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = self.forward_mlp(x)
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
x = x.type(dtype)
return x, residual
class SpectralAttentionS2(nn.Module):
"""
Spherical non-linear FNO layer
"""
def __init__(self,
forward_transform,
inverse_transform,
embed_dim,
operator_type = 'diagonal',
sparsity_threshold = 0.0,
hidden_size_factor = 2,
complex_activation = 'real',
scale = 'auto',
bias = False,
spectral_layers = 1,
drop_rate = 0.):
super(SpectralAttentionS2, self).__init__()
self.embed_dim = embed_dim
self.sparsity_threshold = sparsity_threshold
self.operator_type = operator_type
self.spectral_layers = spectral_layers
if scale == 'auto':
self.scale = (1 / (embed_dim * embed_dim))
self.modes_lat = forward_transform.lmax
self.modes_lon = forward_transform.mmax
# only storing the forward handle to be able to call it
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
assert inverse_transform.lmax == self.modes_lat
assert inverse_transform.mmax == self.modes_lon
hidden_size = int(hidden_size_factor * self.embed_dim)
if operator_type == 'diagonal':
self.mul_add_handle = compl_muladd2d_fwd
self.mul_handle = compl_mul2d_fwd
# weights
w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)]
for l in range(1, self.spectral_layers):
w.append(self.scale * torch.randn(hidden_size, hidden_size, 2))
self.w = nn.ParameterList(w)
self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, self.embed_dim, 2))
if bias:
self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])
self.activations = nn.ModuleList([])
for l in range(0, self.spectral_layers):
self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))
elif operator_type == 'vector':
self.mul_add_handle = compl_exp_muladd2d_fwd
self.mul_handle = compl_exp_mul2d_fwd
# weights
w = [self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)]
for l in range(1, self.spectral_layers):
w.append(self.scale * torch.randn(self.modes_lat, hidden_size, hidden_size, 2))
self.w = nn.ParameterList(w)
if bias:
self.b = nn.ParameterList([self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)])
self.wout = nn.Parameter(self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2))
self.activations = nn.ModuleList([])
for l in range(0, self.spectral_layers):
self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))
else:
raise ValueError('Unknown operator type')
self.drop = nn.Dropout(drop_rate) if drop_rate > 0. else nn.Identity()
def forward_mlp(self, x):
B, C, H, W = x.shape
xr = torch.view_as_real(x)
for l in range(self.spectral_layers):
if hasattr(self, 'b'):
xr = self.mul_add_handle(xr, self.w[l], self.b[l])
else:
xr = self.mul_handle(xr, self.w[l])
xr = torch.view_as_complex(xr)
xr = self.activations[l](xr)
xr = self.drop(xr)
xr = torch.view_as_real(xr)
# final MLP
x = self.mul_handle(xr, self.wout)
x = torch.view_as_complex(x)
return x
def forward(self, x):
dtype = x.dtype
x = x.to(torch.float32)
residual = x
# FWD transform
with amp.autocast(enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
# MLP
x = self.forward_mlp(x)
# BWD transform
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
# cast back to initial precision
x = x.to(dtype)
return x, residual
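A matching sketch for the non-linear spherical filter (illustrative): SpectralAttentionS2 consumes the same SHT pair as SpectralConvS2 and likewise returns an (output, residual) pair.

import torch
from torch_harmonics import RealSHT, InverseRealSHT
from models.layers import SpectralAttentionS2

nlat, nlon = 32, 64
sht = RealSHT(nlat, nlon, grid='equiangular').float()
isht = InverseRealSHT(nlat, nlon, grid='equiangular').float()
attn = SpectralAttentionS2(sht, isht, embed_dim=16, operator_type='diagonal', spectral_layers=2)
y, residual = attn(torch.randn(2, 16, nlat, nlon))
print(y.shape)  # torch.Size([2, 16, 32, 64])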
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.nn as nn
from apex.normalization import FusedLayerNorm
from torch_harmonics import *
from models.layers import *
class SpectralFilterLayer(nn.Module):
"""
Fourier layer. Contains the convolution part of the FNO/SFNO
"""
def __init__(
self,
forward_transform,
inverse_transform,
embed_dim,
filter_type = 'non-linear',
operator_type = 'diagonal',
sparsity_threshold = 0.0,
use_complex_kernels = True,
hidden_size_factor = 2,
factorization = None,
separable = False,
rank = 1e-2,
complex_activation = 'real',
spectral_layers = 1,
drop_rate = 0):
super(SpectralFilterLayer, self).__init__()
if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
self.filter = SpectralAttentionS2(forward_transform,
inverse_transform,
embed_dim,
operator_type = operator_type,
sparsity_threshold = sparsity_threshold,
hidden_size_factor = hidden_size_factor,
complex_activation = complex_activation,
spectral_layers = spectral_layers,
drop_rate = drop_rate,
bias = False)
elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
self.filter = SpectralAttention2d(forward_transform,
inverse_transform,
embed_dim,
sparsity_threshold = sparsity_threshold,
use_complex_kernels = use_complex_kernels,
hidden_size_factor = hidden_size_factor,
complex_activation = complex_activation,
spectral_layers = spectral_layers,
drop_rate = drop_rate,
bias = False)
elif filter_type == 'linear':
self.filter = SpectralConvS2(forward_transform,
inverse_transform,
embed_dim,
embed_dim,
operator_type = operator_type,
rank = rank,
factorization = factorization,
separable = separable,
bias = True)
else:
            raise NotImplementedError(f"Unknown filter type {filter_type}")
def forward(self, x):
return self.filter(x)
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
"""
def __init__(
self,
forward_transform,
inverse_transform,
embed_dim,
filter_type = 'non-linear',
operator_type = 'diagonal',
mlp_ratio = 2.,
drop_rate = 0.,
drop_path = 0.,
act_layer = nn.GELU,
norm_layer = (nn.LayerNorm, nn.LayerNorm),
sparsity_threshold = 0.0,
use_complex_kernels = True,
factorization = None,
separable = False,
rank = 128,
inner_skip = 'linear',
        outer_skip = None, # None, 'linear' or 'identity'
concat_skip = False,
use_mlp = True,
complex_activation = 'real',
spectral_layers = 3):
super(SphericalFourierNeuralOperatorBlock, self).__init__()
# norm layer
self.norm0 = norm_layer[0]() #((h,w))
# convolution layer
self.filter = SpectralFilterLayer(forward_transform,
inverse_transform,
embed_dim,
filter_type,
operator_type = operator_type,
sparsity_threshold = sparsity_threshold,
use_complex_kernels = use_complex_kernels,
hidden_size_factor = mlp_ratio,
factorization = factorization,
separable = separable,
rank = rank,
complex_activation = complex_activation,
spectral_layers = spectral_layers,
drop_rate = drop_rate)
if inner_skip == 'linear':
self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
elif inner_skip == 'identity':
self.inner_skip = nn.Identity()
self.concat_skip = concat_skip
if concat_skip and inner_skip is not None:
self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
if filter_type == 'linear' or filter_type == 'local':
self.act_layer = act_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# norm layer
self.norm1 = norm_layer[1]() #((h,w))
        if use_mlp:
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = MLP(in_features = embed_dim,
hidden_features = mlp_hidden_dim,
act_layer = act_layer,
drop_rate = drop_rate,
checkpointing = False)
if outer_skip == 'linear':
self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
elif outer_skip == 'identity':
self.outer_skip = nn.Identity()
if concat_skip and outer_skip is not None:
self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
def forward(self, x):
x = self.norm0(x)
x, residual = self.filter(x)
if hasattr(self, 'inner_skip'):
if self.concat_skip:
x = torch.cat((x, self.inner_skip(residual)), dim=1)
x = self.inner_skip_conv(x)
else:
x = x + self.inner_skip(residual)
if hasattr(self, 'act_layer'):
x = self.act_layer(x)
x = self.norm1(x)
if hasattr(self, 'mlp'):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, 'outer_skip'):
if self.concat_skip:
x = torch.cat((x, self.outer_skip(residual)), dim=1)
x = self.outer_skip_conv(x)
else:
x = x + self.outer_skip(residual)
return x
class SphericalFourierNeuralOperatorNet(nn.Module):
"""
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
both linear and non-linear variants.
Parameters
----------
filter_type : str, optional
Type of filter to use ('linear', 'non-linear'), by default "linear"
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('vector', 'diagonal'), by default "vector"
    img_size : tuple, optional
        Shape of the input grid (nlat, nlon), by default (128, 256)
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Dimension of the embeddings, by default 256
num_layers : int, optional
Number of layers in the network, by default 4
activation_function : str, optional
Activation function to use, by default "gelu"
encoder_layers : int, optional
Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP hidden dimension to embedding dimension, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
sparsity_threshold : float, optional
Threshold for sparsity, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
use_complex_kernels : bool, optional
Whether to use complex kernels, by default True
big_skip : bool, optional
Whether to add a single large skip connection, by default True
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    rank : int, float or Tuple[int], optional
        If a factorization is used, which rank to use; passed to tensorly. By default 128
complex_activation : str, optional
Type of complex activation function to use, by default "real"
spectral_layers : int, optional
Number of spectral layers, by default 3
pos_embed : bool, optional
Whether to use positional embedding, by default True
Example:
--------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=2,
    ...         encoder_layers=1,
    ...         spectral_layers=2,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
filter_type = 'linear',
spectral_transform = 'sht',
operator_type = 'vector',
img_size = (128, 256),
scale_factor = 3,
in_chans = 3,
out_chans = 3,
embed_dim = 256,
num_layers = 4,
activation_function = 'gelu',
encoder_layers = 1,
use_mlp = True,
mlp_ratio = 2.,
drop_rate = 0.,
drop_path_rate = 0.,
sparsity_threshold = 0.0,
normalization_layer = 'instance_norm',
hard_thresholding_fraction = 1.0,
use_complex_kernels = True,
big_skip = True,
factorization = None,
separable = False,
rank = 128,
complex_activation = 'real',
spectral_layers = 2,
pos_embed = True):
super(SphericalFourierNeuralOperatorNet, self).__init__()
self.filter_type = filter_type
self.spectral_transform = spectral_transform
self.operator_type = operator_type
self.img_size = img_size
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = self.num_features = embed_dim
self.pos_embed_dim = self.embed_dim
self.num_layers = num_layers
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
self.big_skip = big_skip
self.factorization = factorization
        self.separable = separable
self.rank = rank
self.complex_activation = complex_activation
self.spectral_layers = spectral_layers
# activation function
if activation_function == 'relu':
self.activation_function = nn.ReLU
elif activation_function == 'gelu':
self.activation_function = nn.GELU
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size
self.h = self.img_size[0] // scale_factor
self.w = self.img_size[1] // scale_factor
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
# pick norm layer
if self.normalization_layer == "layer_norm":
norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
elif self.normalization_layer == "instance_norm":
norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
norm_layer1 = norm_layer0
elif self.normalization_layer == "none":
norm_layer0 = nn.Identity
norm_layer1 = norm_layer0
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
else:
self.pos_embed = None
# encoder
encoder_hidden_dim = self.embed_dim
current_dim = self.in_chans
encoder_modules = []
for i in range(self.encoder_layers):
encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
encoder_modules.append(self.activation_function())
current_dim = encoder_hidden_dim
encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
self.encoder = nn.Sequential(*encoder_modules)
# prepare the spectral transform
if self.spectral_transform == 'sht':
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
elif self.spectral_transform == 'fft':
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
            raise ValueError(f"Unknown spectral transform {self.spectral_transform}")
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
first_layer = i == 0
last_layer = i == self.num_layers-1
forward_transform = self.trans_down if first_layer else self.trans
inverse_transform = self.itrans_up if last_layer else self.itrans
inner_skip = 'linear'
outer_skip = 'identity'
if first_layer:
norm_layer = (norm_layer0, norm_layer1)
elif last_layer:
norm_layer = (norm_layer1, norm_layer0)
else:
norm_layer = (norm_layer1, norm_layer1)
block = SphericalFourierNeuralOperatorBlock(forward_transform,
inverse_transform,
self.embed_dim,
filter_type = filter_type,
operator_type = self.operator_type,
mlp_ratio = mlp_ratio,
drop_rate = drop_rate,
drop_path = dpr[i],
act_layer = self.activation_function,
norm_layer = norm_layer,
sparsity_threshold = sparsity_threshold,
use_complex_kernels = use_complex_kernels,
inner_skip = inner_skip,
outer_skip = outer_skip,
use_mlp = use_mlp,
factorization = self.factorization,
separable = self.separable,
rank = self.rank,
complex_activation = self.complex_activation,
spectral_layers = self.spectral_layers)
self.blocks.append(block)
# decoder
decoder_hidden_dim = self.embed_dim
current_dim = self.embed_dim + self.big_skip*self.in_chans
decoder_modules = []
for i in range(self.encoder_layers):
decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True))
decoder_modules.append(self.activation_function())
current_dim = decoder_hidden_dim
decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False))
self.decoder = nn.Sequential(*decoder_modules)
# trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
#nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.big_skip:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.forward_features(x)
if self.big_skip:
x = torch.cat((x, residual), dim=1)
x = self.decoder(x)
return x
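A companion sketch to the docstring example (illustrative): the same class yields a planar FNO when built on FFTs instead of SHTs, with operator_type='diagonal' as in the FNO configurations of the training script below. Note the module imports apex's FusedLayerNorm, so apex must be installed for the import to succeed.

import torch
from models.sfno import SphericalFourierNeuralOperatorNet

fno = SphericalFourierNeuralOperatorNet(spectral_transform='fft', filter_type='linear',
                                        operator_type='diagonal', img_size=(64, 128),
                                        scale_factor=2, in_chans=2, out_chans=2,
                                        embed_dim=16, num_layers=2)
print(fno(torch.randn(1, 2, 64, 128)).shape)  # torch.Size([1, 2, 64, 128])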
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import time
from tqdm import tqdm
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda import amp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# wandb logging
import wandb
wandb.login()
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from pde_sphere import SphereSolver
def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
if relative:
loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=False):
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
    # m = 0 is counted once; m > 0 modes are doubled to account for the
    # conjugate negative-m coefficients omitted by the real-valued SHT
    norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=False):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
coeffs = spectral_weights * coeffs
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_coeffs = spectral_weights * tar_coeffs
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def h1loss_sphere(solver, prd, tar, relative=False, squared=False):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
h1_coeffs = spectral_weights * coeffs
h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
l2_loss = torch.sum(l2_norm2, dim=(-1,-2))
# strictly speaking this is not exactly h1 loss
if not squared:
loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
else:
loss = h1_loss + l2_loss
if relative:
raise NotImplementedError("Relative H1 loss not implemented")
loss = loss.mean()
return loss
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
# compute the weighting factor first
fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
# weight = weight.reshape(*weight.shape, 1, 1)
loss = weight * solver.integrate_grid((prd - tar)**2, dimensionless=True, polar_opt=polar_opt)
if relative:
loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
loss = torch.mean(loss)
return loss
def main(train=True, load_checkpoint=False, enable_amp=False):
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# set device
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(device.index)
# dataset
from utils.pde_dataset import PdeDataset
# 1 hour prediction steps
dt = 1*3600
dt_solver = 150
nsteps = dt//dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
solver = dataset.solver.to(device)
nlat = dataset.nlat
nlon = dataset.nlon
# training function
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn='l2'):
train_start = time.time()
for epoch in range(nepochs):
# time each epoch
epoch_start = time.time()
dataloader.dataset.set_initial_condition('random')
dataloader.dataset.set_num_examples(num_examples)
# do the training
acc_loss = 0
model.train()
for inp, tar in dataloader:
with amp.autocast(enabled=enable_amp):
prd = model(inp)
for _ in range(nfuture):
prd = model(prd)
if loss_fn == 'l2':
loss = l2loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'h1':
loss = h1loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'spectral':
loss = spectral_loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'fluct':
loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
else:
raise NotImplementedError(f'Unknown loss function {loss_fn}')
acc_loss += loss.item() * inp.size(0)
optimizer.zero_grad(set_to_none=True)
gscaler.scale(loss).backward()
gscaler.step(optimizer)
gscaler.update()
acc_loss = acc_loss / len(dataloader.dataset)
dataloader.dataset.set_initial_condition('random')
dataloader.dataset.set_num_examples(num_valid)
# perform validation
valid_loss = 0
model.eval()
with torch.no_grad():
for inp, tar in dataloader:
prd = model(inp)
for _ in range(nfuture):
prd = model(prd)
loss = l2loss_sphere(solver, prd, tar, relative=True)
valid_loss += loss.item() * inp.size(0)
valid_loss = valid_loss / len(dataloader.dataset)
if scheduler is not None:
scheduler.step(valid_loss)
epoch_time = time.time() - epoch_start
print(f'--------------------------------------------------------------------------------')
print(f'Epoch {epoch} summary:')
print(f'time taken: {epoch_time}')
print(f'accumulated training loss: {acc_loss}')
print(f'relative validation loss: {valid_loss}')
if wandb.run is not None:
current_lr = optimizer.param_groups[0]['lr']
wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})
train_time = time.time() - train_start
print(f'--------------------------------------------------------------------------------')
print(f'done. Training took {train_time}.')
return valid_loss
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=20):
model.eval()
losses = np.zeros(nics)
fno_times = np.zeros(nics)
nwp_times = np.zeros(nics)
for iic in range(nics):
ic = dataset.solver.random_initial_condition(mach=0.2)
inp_mean = dataset.inp_mean
inp_var = dataset.inp_var
prd = (dataset.solver.spec2grid(ic) - inp_mean) / torch.sqrt(inp_var)
prd = prd.unsqueeze(0)
uspec = ic.clone()
# ML model
start_time = time.time()
for i in range(1, autoreg_steps+1):
# evaluate the ML model
prd = model(prd)
if iic == nics-1 and nskip > 0 and i % nskip == 0:
# do plotting
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root+'_pred_'+str(i//nskip)+'.png')
plt.clf()
fno_times[iic] = time.time() - start_time
# classical model
start_time = time.time()
for i in range(1, autoreg_steps+1):
# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)
if iic == nics-1 and i % nskip == 0 and nskip > 0:
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root+'_truth_'+str(i//nskip)+'.png')
plt.clf()
nwp_times[iic] = time.time() - start_time
        # compare in physical units: undo the normalization of the prediction
        ref = dataset.solver.spec2grid(uspec)
        prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(solver, prd, ref, relative=True).item()
return losses, fno_times, nwp_times
def count_parameters(model):
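    """Returns the number of trainable parameters of the model"""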
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# prepare dicts containing models and corresponding metrics
models = {}
metrics = {}
# # U-Net if installed
# from models.unet import UNet
# models['unet_baseline'] = partial(UNet)
# SFNO and FNO models
from models.sfno import SphericalFourierNeuralOperatorNet as SFNO
# SFNO models
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
                                                 num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real', operator_type='diagonal')
# FNO models
models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', filter_type='linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')
models['fno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='fft', filter_type='non-linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real')
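# note: all four configurations share the same network; they differ only in the
# spectral transform (SHT vs. FFT) and in the filter/operator parameterization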
# iterate over models and train each model
root_path = os.path.dirname(__file__)
for model_name, model_handle in models.items():
model = model_handle().to(device)
metrics[model_name] = {}
num_params = count_parameters(model)
print(f'number of trainable params: {num_params}')
metrics[model_name]['num_params'] = num_params
if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(root_path, 'checkpoints/'+model_name)))
# run the training
if train:
run = wandb.init(project="sfno spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)
# optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=1E-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
start_time = time.time()
print(f'Training {model_name}, single step')
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn='l2')
# multistep training
print(f'Training {model_name}, two step')
optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
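        # double the solver steps per sample so that each target lies two model steps
        # (2*dt) ahead; with nfuture=1, the model is unrolled twice to match it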
dataloader.dataset.nsteps = 2 * dt//dt_solver
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1)
dataloader.dataset.nsteps = 1 * dt//dt_solver
training_time = time.time() - start_time
run.finish()
torch.save(model.state_dict(), os.path.join(root_path, 'checkpoints/'+model_name))
    # fix the seed so that every model is evaluated on the same initial conditions
torch.manual_seed(333)
torch.cuda.manual_seed(333)
with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path,'paper_figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
metrics[model_name]['loss_mean'] = np.mean(losses)
metrics[model_name]['loss_std'] = np.std(losses)
metrics[model_name]['fno_time_mean'] = np.mean(fno_times)
metrics[model_name]['fno_time_std'] = np.std(fno_times)
metrics[model_name]['nwp_time_mean'] = np.mean(nwp_times)
metrics[model_name]['nwp_time_std'] = np.std(nwp_times)
if train:
metrics[model_name]['training_time'] = training_time
df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, 'output_data/metrics.pkl'))
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.set_start_method('forkserver', force=True)
main(train=True, load_checkpoint=False, enable_amp=False)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import os
from math import ceil
import sys
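# make the torch_harmonics package and the examples folder importable when this
# file is run directly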
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "torch_harmonics"))
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "examples"))
from shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random',
num_examples=32, device=torch.device('cpu'), normalize=True, stream=None):
self.num_examples = num_examples
self.device = device
self.stream = stream
self.nlat = dims[0]
self.nlon = dims[1]
# number of solver steps used to compute the target
self.nsteps = nsteps
self.normalize = normalize
if pde == 'shallow water equations':
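            # triangular truncation at roughly a third of the number of latitudes,
            # which comfortably avoids aliasing of the quadratic nonlinearities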
lmax = ceil(self.nlat/3)
mmax = lmax
dt_solver = dt / float(self.nsteps)
self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()
else:
raise NotImplementedError
self.set_initial_condition(ictype=initial_condition)
if self.normalize:
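            # estimate per-channel normalization statistics from a single random sample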
inp0, _ = self._get_sample()
self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
def __len__(self):
length = self.num_examples if self.ictype == 'random' else 1
return length
def set_initial_condition(self, ictype='random'):
self.ictype = ictype
def set_num_examples(self, num_examples=32):
self.num_examples = num_examples
    def _get_sample(self):
        if self.ictype == 'random':
            inp = self.solver.random_initial_condition(mach=0.2)
        elif self.ictype == 'galewsky':
            inp = self.solver.galewsky_initial_condition()
        else:
            raise NotImplementedError(f'Unknown initial condition type {self.ictype}')
# solve pde for n steps to return the target
tar = self.solver.timestep(inp, self.nsteps)
inp = self.solver.spec2grid(inp)
tar = self.solver.spec2grid(tar)
return inp, tar
    def __getitem__(self, index):
        # generate the sample on the fly; inference_mode disables autograd tracking,
        # so the tensors are cloned before they are returned
        with torch.inference_mode():
            inp, tar = self._get_sample()
            if self.normalize:
                inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
                tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)
        return inp.clone(), tar.clone()