Commit b6b2bce3 authored by Boris Bonev, committed by Boris Bonev

implemented Zernike filter basis

parent 7126fb9a
@@ -11,6 +11,7 @@
 * New filter basis normalization in DISCO convolutions
 * Reworked DISCO filter basis datastructure
 * Support for new filter basis types
+* Adding Zernike polynomial basis on a disk
 * Adding Morlet wavelet basis functions on a spherical disk
 * Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
 * Updated resampling module to extend input signal to the poles if needed
...
@@ -430,7 +430,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         normalization_layer="none",
     )

-    models[f"lsno_sc2_layers4_e32"] = partial(
+    models[f"lsno_sc2_layers4_e32_morlet"] = partial(
         LSNO,
         img_size=(nlat, nlon),
         grid=grid,
@@ -443,6 +443,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         pos_embed=False,
         use_mlp=True,
         normalization_layer="none",
+        kernel_shape=[4, 4],
+        encoder_kernel_shape=[4, 4],
+        filter_basis_type="morlet"
+    )
+
+    models[f"lsno_sc2_layers4_e32_zernike"] = partial(
+        LSNO,
+        img_size=(nlat, nlon),
+        grid=grid,
+        num_layers=4,
+        scale_factor=2,
+        embed_dim=32,
+        operator_type="driscoll-healy",
+        activation_function="gelu",
+        big_skip=False,
+        pos_embed=False,
+        use_mlp=True,
+        normalization_layer="none",
+        kernel_shape=[4],
+        encoder_kernel_shape=[4],
+        filter_basis_type="zernike"
     )

     # iterate over models and train each model
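Note: the two configurations use different `kernel_shape` conventions because the bases are indexed differently: the Morlet basis is a tensor product over radial and azimuthal shape functions, while the Zernike basis is a triangular family parameterized by a single maximum degree. A minimal sketch of the resulting basis sizes, following the `kernel_size` properties added further down in this commit:

```python
# number of DISCO filter basis functions per configuration (sketch)
morlet_shape = [4, 4]
morlet_size = morlet_shape[0] * morlet_shape[1]                 # 16 functions
zernike_shape = [4]
zernike_size = zernike_shape[0] * (zernike_shape[0] + 1) // 2   # 10 functions
print(morlet_size, zernike_size)
```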
@@ -468,7 +489,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
        # run the training
        if train:
-            run = wandb.init(project="sfno ablations spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
+            run = wandb.init(project="local sno spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)

            # optimizer:
            optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
@@ -478,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
            start_time = time.time()
            print(f"Training {model_name}, single step")
-            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
+            train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=100, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)

            if nfuture > 0:
                print(f'Training {model_name}, {nfuture} step')
...
@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
            grid_out=grid_out,
            groups=groups,
            bias=bias,
-            theta_cutoff=math.sqrt(2.0) * torch.pi / float(out_shape[0] - 1),
+            theta_cutoff=1.0 * torch.pi / float(out_shape[0] - 1),
        )

    def forward(self, x):
@@ -103,7 +103,7 @@ class DiscreteContinuousDecoder(nn.Module):
        # set up
        self.sht = RealSHT(*in_shape, grid=grid_in).float()
        self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
-        # self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
+        self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)

        # set up DISCO convolution
        self.conv = DiscreteContinuousConvS2(
@@ -117,28 +117,25 @@ class DiscreteContinuousDecoder(nn.Module):
            grid_out=grid_out,
            groups=groups,
            bias=False,
-            theta_cutoff=math.sqrt(2.0) * torch.pi / float(in_shape[0] - 1),
+            theta_cutoff=1.0 * torch.pi / float(in_shape[0] - 1),
        )

-        # self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
-
    def upscale_sht(self, x: torch.Tensor):
        return self.isht(self.sht(x))

    def forward(self, x):
        dtype = x.dtype

-        # x = self.upscale(x)
+        x = self.upscale(x)

        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
-            x = self.upscale_sht(x)
+            # x = self.upscale_sht(x)
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x


class SphericalNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
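Note: the decoder now upsamples with the resampling module before the DISCO convolution instead of an SHT round trip (the `upscale_sht` path is kept but commented out), and the resampling runs in the ambient precision while only the convolution stays inside the float32 autocast-disabled region. A minimal sketch of the two paths, assuming `ResampleS2` is importable from `torch_harmonics` and using the constructor signature shown in the diff above:

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT, ResampleS2  # import path assumed

in_shape, out_shape = (60, 120), (120, 240)
x = torch.randn(1, 8, *in_shape)

# previous path: band-limited upsampling via a spherical harmonic round trip
sht = RealSHT(*in_shape, grid="equiangular").float()
isht = InverseRealSHT(*out_shape, lmax=sht.lmax, mmax=sht.mmax, grid="equiangular").float()
x_old = isht(sht(x))

# new path: direct resampling on the sphere, as in the updated forward()
upscale = ResampleS2(*in_shape, *out_shape, grid_in="equiangular", grid_out="equiangular")
x_new = upscale(x)

print(x_old.shape, x_new.shape)  # both: torch.Size([1, 8, 120, 240])
```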
@@ -160,7 +157,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
        inner_skip="None",
        outer_skip="linear",
        use_mlp=True,
-        disco_kernel_shape=[2, 4],
+        disco_kernel_shape=[3, 4],
        disco_basis_type="piecewise linear",
    ):
        super().__init__()
@@ -185,10 +182,10 @@ class SphericalNeuralOperatorBlock(nn.Module):
                grid_in=forward_transform.grid,
                grid_out=inverse_transform.grid,
                bias=False,
-                theta_cutoff=4 * math.sqrt(2.0) * torch.pi / float(inverse_transform.nlat - 1),
+                theta_cutoff=1.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
            )
        elif conv_type == "global":
            self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
        else:
            raise ValueError(f"Unknown convolution type {conv_type}")
@@ -294,6 +291,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
        Activation function to use, by default "gelu"
    encoder_kernel_shape : int, optional
        size of the encoder kernel
+    filter_basis_type : str, optional
+        Filter basis type, by default "piecewise linear"
    use_mlp : int, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : int, optional
@@ -350,7 +349,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
        activation_function="relu",
        kernel_shape=[3, 4],
        encoder_kernel_shape=[3, 4],
-        disco_basis_type="piecewise linear",
+        filter_basis_type="piecewise linear",
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
@@ -423,18 +422,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
            self.pos_embed = None

        # encoder
-        self.encoder = DiscreteContinuousConvS2(
-            self.in_chans,
-            self.embed_dim,
-            self.img_size,
-            (self.h, self.w),
-            self.encoder_kernel_shape,
-            basis_type=disco_basis_type,
-            groups=1,
+        self.encoder = DiscreteContinuousEncoder(
+            in_shape=self.img_size,
+            out_shape=(self.h, self.w),
            grid_in=grid,
            grid_out=grid_internal,
+            inp_chans=self.in_chans,
+            out_chans=self.embed_dim,
+            kernel_shape=self.encoder_kernel_shape,
+            basis_type=filter_basis_type,
+            groups=1,
            bias=False,
-            theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
        )

        # prepare the SHT
@@ -476,7 +474,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
                outer_skip=outer_skip,
                use_mlp=use_mlp,
                disco_kernel_shape=kernel_shape,
-                disco_basis_type=disco_basis_type,
+                disco_basis_type=filter_basis_type,
            )
            self.blocks.append(block)
@@ -490,7 +488,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
            inp_chans=self.embed_dim,
            out_chans=self.out_chans,
            kernel_shape=self.encoder_kernel_shape,
-            basis_type=disco_basis_type,
+            basis_type=filter_basis_type,
            groups=1,
            bias=False,
        )
@@ -503,7 +501,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
        # scale = math.sqrt(0.5 / self.in_chans)
        # nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
-
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}
...
@@ -44,7 +44,7 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
    elif basis_type == "morlet":
        return MorletFilterBasis(kernel_shape=kernel_shape)
    elif basis_type == "zernike":
-        raise NotImplementedError()
+        return ZernikeFilterBasis(kernel_shape=kernel_shape)
    else:
        raise ValueError(f"Unknown basis_type {basis_type}")
@@ -54,6 +54,16 @@ def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
    return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))

+def _log_factorial(x: torch.Tensor):
+    """Helper function to compute the log factorial on a torch tensor"""
+    return torch.lgamma(x + 1)
+
+
+def _factorial(x: torch.Tensor):
+    """Helper function to compute the factorial on a torch tensor"""
+    return torch.exp(_log_factorial(x))
+
+
class FilterBasis(metaclass=abc.ABCMeta):
    """
    Abstract base class for a filter basis
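Note: `_factorial` evaluates n! elementwise through the log-gamma function, since torch has no tensor factorial; the result is exact for small integers and subject to floating-point round-off for large ones. A small check of the helpers as defined above:

```python
import math
import torch

def _log_factorial(x: torch.Tensor):
    return torch.lgamma(x + 1)

def _factorial(x: torch.Tensor):
    return torch.exp(_log_factorial(x))

n = torch.arange(6, dtype=torch.float32)
print(_factorial(n))                          # tensor([  1.,   1.,   2.,   6.,  24., 120.])
print([math.factorial(k) for k in range(6)])  # [1, 1, 2, 6, 24, 120]
```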
@@ -226,7 +236,7 @@ class MorletFilterBasis(FilterBasis):
    def kernel_size(self):
        return self.kernel_shape[0] * self.kernel_shape[1]

-    def _gaussian_window(self, r: torch.Tensor, width: float = 1.0):
+    def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
        return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))

    def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
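Note: with the default width the window is a unit-mass 2-D Gaussian (the 1/(2πσ²) prefactor normalizes it over the plane); the rename only drops the leading underscore. A quick sanity check of the normalization by integrating in polar coordinates:

```python
import math
import torch

def gaussian_window(r: torch.Tensor, width: float = 1.0):
    return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))

r = torch.linspace(0.0, 10.0, 100001)
mass = torch.trapezoid(gaussian_window(r) * 2 * math.pi * r, r)
print(mass)  # ~1.0
```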
@@ -254,6 +264,76 @@
            disk_area = 1.0

        # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
-        vals = self._gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
+        vals = self.gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
+
+        return iidx, vals
+
+
+class ZernikeFilterBasis(FilterBasis):
+    """
+    Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
+    """
+
+    def __init__(
+        self,
+        kernel_shape: Union[int, Tuple[int], List[int]],
+    ):
+        if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
+            kernel_shape = kernel_shape[0]
+        if not isinstance(kernel_shape, int):
+            raise ValueError(f"expected kernel_shape to be an integer but got {kernel_shape} instead.")
+
+        super().__init__(kernel_shape=kernel_shape)
+
+    @property
+    def kernel_size(self):
+        return (self.kernel_shape * (self.kernel_shape + 1)) // 2
+
+    def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
+        out = torch.zeros_like(r)
+        bound = (n - m) // 2 + 1
+        max_bound = bound.max().item()
+        for k in range(max_bound):
+            inc = (-1) ** k * _factorial(n - k) * r ** (n - 2 * k) / (math.factorial(k) * _factorial((n + m) // 2 - k) * _factorial((n - m) // 2 - k))
+            out += torch.where(k < bound, inc, 0.0)
+        return out
+
+    def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
+        m = 2 * l - n
+        return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi))
+
+    def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
+        """
+        Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
+        """
+
+        # enumerator for basis function
+        ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
+
+        # get relevant indices
+        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+
+        # indexing logic for zernike polynomials
+        # the total index is given by (n * (n + 2) + m) // 2 = n * (n + 1) // 2 + l, which needs to be reversed
+        # precompute shifts in the level of the "pyramid"
+        nshifts = torch.arange(self.kernel_shape)
+        nshifts = (nshifts + 1) * nshifts // 2
+
+        # find the level and position within the pyramid
+        nkernel = torch.searchsorted(nshifts, ikernel, right=True) - 1
+        lkernel = ikernel - nshifts[nkernel]
+        # mkernel = 2 * ikernel - nkernel * (nkernel + 2)
+
+        # get corresponding coordinates and n and l indices
+        r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
+        phi = phi[iidx[:, 1], iidx[:, 2]]
+        n = nkernel[iidx[:, 0], 0, 0]
+        l = lkernel[iidx[:, 0], 0, 0]
+
+        # computes the Zernike polynomials using helper functions
+        vals = self.zernikepoly(r, phi, n, l)
+
        return iidx, vals
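Note: a standalone sketch of the pyramid indexing and the radial series, mirroring the logic above, to make the conventions concrete. The flat index j maps to degree n and position l (so m = 2l − n), and the radial part follows the standard Zernike series; R_2^0(r) = 2r² − 1 serves as a spot check:

```python
import math
import torch

# decode flat kernel indices into (n, l, m), mirroring compute_support_vals
kernel_shape = 4
kernel_size = kernel_shape * (kernel_shape + 1) // 2  # 10
ikernel = torch.arange(kernel_size)
nshifts = torch.arange(kernel_shape)
nshifts = (nshifts + 1) * nshifts // 2                # tensor([0, 1, 3, 6])
n = torch.searchsorted(nshifts, ikernel, right=True) - 1
l = ikernel - nshifts[n]
m = 2 * l - n
print(list(zip(n.tolist(), l.tolist(), m.tolist())))
# [(0, 0, 0), (1, 0, -1), (1, 1, 1), (2, 0, -2), (2, 1, 0),
#  (2, 2, 2), (3, 0, -3), (3, 1, -1), (3, 2, 1), (3, 3, 3)]

# radial polynomial R_n^m for scalar n, m (same series as zernikeradial)
def radial(r: torch.Tensor, n: int, m: int):
    out = torch.zeros_like(r)
    for k in range((n - m) // 2 + 1):
        out = out + (-1) ** k * math.factorial(n - k) * r ** (n - 2 * k) / (
            math.factorial(k) * math.factorial((n + m) // 2 - k) * math.factorial((n - m) // 2 - k)
        )
    return out

r = torch.linspace(0.0, 1.0, 5)
print(radial(r, 2, 0))  # tensor([-1.0000, -0.8750, -0.5000,  0.1250,  1.0000])
print(2 * r**2 - 1)     # matches R_2^0(r) = 2r^2 - 1
```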