Commit 6a845fd3 authored by Boris Bonev, committed by Boris Bonev

adding spherical attention

parent b3816ebc
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes
from .losses import get_quadrature_weights
# routine to compute multiclass confusion statistics (tp/fp/fn/tn) on the sphere
# the routine follows the implementation in
# https://github.com/qubvel-org/segmentation_models.pytorch/blob/4aa36c6ad13f8a12552e4ea4131af2a86e564962/segmentation_models_pytorch/metrics/functional.py
# but uses quadrature weights
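# (each grid point contributes its quadrature / area weight instead of a unit count,
# so the oversampled polar rows of an equiangular grid are not over-counted)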
def _get_stats_multiclass(
output: torch.LongTensor,
target: torch.LongTensor,
num_classes: int,
quad_weights: torch.Tensor,
ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
if ignore_index is not None:
ignore = target == ignore_index
output = torch.where(ignore, -1, output)
target = torch.where(ignore, -1, target)
ignore_per_sample = ignore.view(batch_size, -1).sum(1)
tp_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
fp_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
fn_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
tn_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
matched = target == output
not_matched = target != output
for i in range(batch_size):
matched_i = matched[i, ...]
not_matched_i = not_matched[i, ...]
target_i = target[i, ...]
output_i = output[i, ...]
for c in range(num_classes):
# compute weights
qwt_c = quad_weights[target_i == c]
qwo_c = quad_weights[output_i == c]
# true positives
tp_count[i, c] = torch.sum(matched_i[target_i == c] * qwt_c)
# false positives
fp_count[i, c] = torch.sum(not_matched_i[output_i == c] * qwo_c)
# false negatives
fn_count[i, c] = torch.sum(not_matched_i[target_i == c] * qwt_c)
# true negatives are whatever remains
tn_count = torch.sum(quad_weights) - tp_count - fp_count - fn_count
return tp_count, fp_count, fn_count, tn_count
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)
class BaseMetricS2(nn.Module):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
self.ignore_index = ignore_index
self.mode = mode
# area weights
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid, tile=True)
self.register_buffer("quad_weights", q)
if weight is None:
self.weight = None
else:
self.register_buffer("weight", weight.unsqueeze(0))
def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# convert logits to class predictions
pred_class = _predict_classes(pred)
# get true positive, false positive, etc
tp, fp, fn, tn = _get_stats_multiclass(pred_class, truth, pred.shape[1], self.quad_weights, self.ignore_index)
# compute averages:
if self.mode == "micro":
if self.weight is not None:
# weighted average
tp = torch.sum(tp * self.weight)
fp = torch.sum(fp * self.weight)
fn = torch.sum(fn * self.weight)
tn = torch.sum(tn * self.weight)
else:
# normal average
tp = torch.mean(tp)
fp = torch.mean(fp)
fn = torch.mean(fn)
tn = torch.mean(tn)
else:
tp = torch.mean(tp, dim=0)
fp = torch.mean(fp, dim=0)
fn = torch.mean(fn, dim=0)
tn = torch.mean(tn, dim=0)
return tp, fp, fn, tn
class IntersectionOverUnionS2(BaseMetricS2):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
score = tp / (tp + fp + fn)
if self.mode == "macro":
# we need to do some averaging still:
# be careful with zeros
score = torch.where(torch.isnan(score), 0.0, score)
if self.weight is not None:
score = torch.sum(score * self.weight)
else:
score = torch.mean(score)
return score
class AccuracyS2(BaseMetricS2):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
score = (tp + tn) / (tp + fp + fn + tn)
if self.mode == "macro":
# we need to do some averaging still:
# be careful with zeros
score = torch.where(torch.isnan(score), 0.0, score)
if self.weight is not None:
score = torch.sum(score * self.weight)
else:
score = torch.mean(score)
return score
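# Minimal usage sketch (illustrative only, not part of the library): the metric modules above
# expect class logits of shape (batch, num_classes, nlat, nlon) and integer class targets of
# shape (batch, nlat, nlon); the shapes and values below are arbitrary.
if __name__ == "__main__":
    nlat, nlon, num_classes = 64, 128, 4
    iou_metric = IntersectionOverUnionS2(nlat, nlon, grid="equiangular", mode="macro")
    acc_metric = AccuracyS2(nlat, nlon, grid="equiangular", mode="micro")
    logits = torch.randn(2, num_classes, nlat, nlon)
    target = torch.randint(0, num_classes, (2, nlat, nlon))
    print(iou_metric(logits, target), acc_metric(logits, target))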
......@@ -29,5 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from .sfno import SphericalFourierNeuralOperatorNet
from .lsno import LocalSphericalNeuralOperatorNet
from .sfno import SphericalFourierNeuralOperator
from .lsno import LocalSphericalNeuralOperator
from .s2unet import SphericalUNet
from .s2transformer import SphericalTransformer
from .s2segformer import SphericalSegformer
......@@ -29,26 +29,26 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
import math
import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint
import math
from torch_harmonics import *
from ._activations import *
from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
......@@ -66,7 +66,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
......@@ -74,7 +74,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
......@@ -95,7 +95,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor:
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
......@@ -103,9 +103,9 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) ->
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
if drop_prob == 0.0 or not training:
return x
keep_prob = 1. - drop_prob
keep_prob = 1.0 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
......@@ -114,8 +114,8 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) ->
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
......@@ -123,16 +123,30 @@ class DropPath(nn.Module):
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
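# e.g. with drop_prob=0.25 during training, each sample's residual branch is zeroed with
# probability 0.25; in the standard stochastic-depth formulation the surviving samples are
# rescaled by 1/keep_prob so the expected activation is unchanged. At eval time DropPath is a no-op.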
class PatchEmbed(nn.Module):
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super(PatchEmbed, self).__init__()
self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
num_patches = self.red_img_size[0] * self.red_img_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
self.proj.weight.is_shared_mp = ["spatial"]
self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x):
# gather input
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# new: B, C, H*W
x = self.proj(x).flatten(2)
return x
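# shape sketch: with img_size=(224, 224), patch_size=(16, 16) and embed_dim=768, an input of
# shape (B, 3, 224, 224) is projected to (B, 768, 14, 14) and flattened to (B, 768, 196),
# i.e. one embedding vector per patch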
class MLP(nn.Module):
def __init__(self,
in_features,
hidden_features = None,
out_features = None,
act_layer = nn.ReLU,
output_bias = False,
drop_rate = 0.,
checkpointing = False,
gain = 1.0):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
super(MLP, self).__init__()
self.checkpointing = checkpointing
out_features = out_features or in_features
......@@ -142,7 +156,7 @@ class MLP(nn.Module):
fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
# initialize the weights correctly
scale = math.sqrt(2.0 / in_features)
nn.init.normal_(fc1.weight, mean=0., std=scale)
nn.init.normal_(fc1.weight, mean=0.0, std=scale)
if fc1.bias is not None:
nn.init.constant_(fc1.bias, 0.0)
......@@ -153,11 +167,11 @@ class MLP(nn.Module):
fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
# gain factor for the output determines the scaling of the output init
scale = math.sqrt(gain / hidden_features)
nn.init.normal_(fc2.weight, mean=0., std=scale)
nn.init.normal_(fc2.weight, mean=0.0, std=scale)
if fc2.bias is not None:
nn.init.constant_(fc2.bias, 0.0)
if drop_rate > 0.:
if drop_rate > 0.0:
drop = nn.Dropout2d(drop_rate)
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
......@@ -173,15 +187,13 @@ class MLP(nn.Module):
else:
return self.fwd(x)
class RealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self,
nlat,
nlon,
lmax = None,
mmax = None):
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(RealFFT2, self).__init__()
self.nlat = nlat
......@@ -191,18 +203,16 @@ class RealFFT2(nn.Module):
def forward(self, x):
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., :math.ceil(self.lmax/2), :self.mmax], y[..., -math.floor(self.lmax/2):, :self.mmax]), dim=-2)
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
return y
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self,
nlat,
nlon,
lmax = None,
mmax = None):
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(InverseRealFFT2, self).__init__()
self.nlat = nlat
......@@ -213,6 +223,24 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module):
"""
Wrapper class that moves the channel dimension to the end
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
super().__init__()
self.channel_dim = -3
self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
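# e.g. an input of shape (B, C, H, W) is transposed to (B, W, H, C), normalized over the
# channel dimension C, and transposed back to (B, C, H, W)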
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
......@@ -220,15 +248,7 @@ class SpectralConvS2(nn.Module):
domain via the RealFFT2 and InverseRealFFT2 wrappers.
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
gain = 2.,
operator_type = "driscoll-healy",
lr_scale_exponent = 0,
bias = False):
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super().__init__()
self.forward_transform = forward_transform
......@@ -237,8 +257,7 @@ class SpectralConvS2(nn.Module):
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type
......@@ -266,7 +285,6 @@ class SpectralConvS2(nn.Module):
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
......@@ -287,4 +305,117 @@ class SpectralConvS2(nn.Module):
x = x + self.bias
x = x.type(dtype)
return x, residual
\ No newline at end of file
return x, residual
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for position embeddings which are added to the input
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__()
self.img_shape = img_shape
self.num_chans = num_chans
def forward(self, x: torch.Tensor):
return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding):
"""
Returns standard sequence based position embedding
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
with torch.no_grad():
# alternating custom position embeddings
pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
denom = torch.pow(10000, 2 * k / self.num_chans)
pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))
# register tensor
self.register_buffer("position_embeddings", pos_embed.float())
class SpectralPositionEmbedding(PositionEmbedding):
"""
Position embeddings constructed from individual spherical harmonics via an inverse SHT
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# compute maximum required frequency and prepare isht
lmax = math.floor(math.sqrt(self.num_chans)) + 1
isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
# fill position embedding
with torch.no_grad():
pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
for i in range(self.num_chans):
l = math.floor(math.sqrt(i))
m = i - l**2 - l
if m < 0:
pos_embed_freq[0, i, l, -m] = 1.0j
else:
pos_embed_freq[0, i, l, m] = 1.0
# compute spatial position embeddings
pos_embed = isht(pos_embed_freq)
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
# register tensor
self.register_buffer("position_embeddings", pos_embed)
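# the loop above enumerates spherical harmonic indices in order, i.e. channel i maps to
# (l, m) = (floor(sqrt(i)), i - l**2 - l): 0 -> (0, 0), 1 -> (1, -1), 2 -> (1, 0), 3 -> (1, 1),
# 4 -> (2, -2), ...; each channel is therefore a single (real or imaginary part of a) spherical
# harmonic evaluated on the grid and normalized to unit maximum amplitude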
class LearnablePositionEmbedding(PositionEmbedding):
"""
Learnable position embeddings, either per latitude ring or per latitude-longitude point
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
if embed_type == "latlon":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
else:
raise ValueError(f"Unknown learnable position embedding type {embed_type}")
# class SpiralPositionEmbedding(PositionEmbedding):
# """
# Returns position embeddings on the torus
# """
# def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
# super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# with torch.no_grad():
# # alternating custom position embeddings
# lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
# lats = lats.reshape(-1, 1)
# lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
# lons = lons.reshape(1, -1)
# # channel index
# k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -29,6 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
......@@ -37,10 +39,15 @@ from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import ResampleS2
from ._layers import *
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
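# e.g. for nlat=481, kernel_shape=(3, 3) and basis_type="morlet" this gives
# theta_cutoff = (3 + 1) * 0.5 * pi / 480 = pi / 240 (about 0.0131 rad),
# i.e. roughly two latitude grid spacings of pi / 480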
class DiscreteContinuousEncoder(nn.Module):
def __init__(
......@@ -51,8 +58,8 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
basis_type="piecewise linear",
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
):
......@@ -70,7 +77,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=4.0 * torch.pi / float(out_shape[0] - 1),
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
......@@ -93,11 +100,11 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
basis_type="piecewise linear",
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
upsample_sht=False
upsample_sht=False,
):
super().__init__()
......@@ -121,7 +128,7 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=4.0 * torch.pi / float(in_shape[0] - 1),
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
......@@ -152,12 +159,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.Identity,
norm_layer="none",
inner_skip="none",
outer_skip="identity",
use_mlp=True,
disco_kernel_shape=[3, 4],
disco_basis_type="piecewise linear",
disco_kernel_shape=(3, 3),
disco_basis_type="morlet",
bias=False,
):
super().__init__()
......@@ -171,6 +179,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
# convolution layer
if conv_type == "local":
theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
self.local_conv = DiscreteContinuousConvS2(
input_dim,
output_dim,
......@@ -180,11 +189,11 @@ class SphericalNeuralOperatorBlock(nn.Module):
basis_type=disco_basis_type,
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=False,
theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
bias=bias,
theta_cutoff=theta_cutoff,
)
elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
else:
raise ValueError(f"Unknown convolution type {conv_type}")
......@@ -199,8 +208,15 @@ class SphericalNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
# first normalisation layer
self.norm0 = norm_layer()
# normalisation layer
if norm_layer == "layer_norm":
self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
elif norm_layer == "instance_norm":
self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
......@@ -232,9 +248,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
# second normalisation layer
self.norm1 = norm_layer()
def forward(self, x):
residual = x
......@@ -244,7 +257,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
elif hasattr(self, "local_conv"):
x = self.local_conv(x)
x = self.norm0(x)
x = self.norm(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
......@@ -252,8 +265,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.norm1(x)
x = self.drop_path(x)
if hasattr(self, "outer_skip"):
......@@ -262,7 +273,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
return x
class LocalSphericalNeuralOperatorNet(nn.Module):
class LocalSphericalNeuralOperator(nn.Module):
"""
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
......@@ -300,6 +311,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
sfno_block_frequency : int, optional
How often a (global) SFNO block is used, by default 2
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
big_skip : bool, optional
......@@ -308,6 +321,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to use positional embedding, by default True
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
bias : bool, optional
Whether to use a bias, by default False
Example
-----------
......@@ -345,19 +360,20 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
embed_dim=256,
num_layers=4,
activation_function="gelu",
kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4],
filter_basis_type="piecewise linear",
kernel_shape=(3, 3),
encoder_kernel_shape=(3, 3),
filter_basis_type="morlet",
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
sfno_block_frequency=2,
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
pos_embed=False,
residual_prediction=False,
pos_embed="none",
upsample_sht=False,
bias=False,
):
super().__init__()
......@@ -373,7 +389,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.big_skip = big_skip
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
......@@ -394,30 +410,18 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
# pick norm layer
if self.normalization_layer == "layer_norm":
norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
elif self.normalization_layer == "instance_norm":
norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
elif self.normalization_layer == "none":
norm_layer0 = nn.Identity
norm_layer1 = norm_layer0
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed == "latlon" or pos_embed == True:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.h, self.w))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "lat":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.h, 1))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "const":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
nn.init.constant_(self.pos_embed, 0.0)
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
self.pos_embed = None
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = DiscreteContinuousEncoder(
......@@ -445,30 +449,22 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
first_layer = i == 0
last_layer = i == self.num_layers - 1
if first_layer:
norm_layer = norm_layer1
elif last_layer:
norm_layer = norm_layer0
else:
norm_layer = norm_layer1
block = SphericalNeuralOperatorBlock(
self.trans,
self.itrans,
self.embed_dim,
self.embed_dim,
conv_type="global" if i % 2 == 0 else "local",
conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency-1) else "local",
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
disco_kernel_shape=kernel_shape,
disco_basis_type=filter_basis_type,
bias=bias,
)
self.blocks.append(block)
......@@ -485,17 +481,9 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
basis_type=filter_basis_type,
groups=1,
bias=False,
upsample_sht=upsample_sht
upsample_sht=upsample_sht,
)
# # residual prediction
# if self.big_skip:
# self.residual_transform = nn.Conv2d(self.out_chans, self.in_chans, 1, bias=False)
# self.residual_transform.weight.is_shared_mp = ["spatial"]
# self.residual_transform.weight.sharded_dims_mp = [None, None, None, None]
# scale = math.sqrt(0.5 / self.in_chans)
# nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
......@@ -509,20 +497,19 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
return x
def forward(self, x):
if self.big_skip:
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.big_skip:
# x = x + self.residual_transform(residual)
if self.residual_prediction:
x = x + residual
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class OverlapPatchMerging(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(481, 960),
grid_in="equiangular",
grid_out="equiangular",
in_channels=3,
out_channels=64,
kernel_shape=(3, 3),
basis_type="morlet",
bias=False,
):
super().__init__()
# convolution for patches, cutoff radius inferred from kernel shape
theta_cutoff = _compute_cutoff_radius(out_shape[0], kernel_shape, basis_type)
self.conv = DiscreteContinuousConvS2(
in_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=bias,
theta_cutoff=theta_cutoff,
)
# layer norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.conv(x).to(dtype=dtype)
# permute
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
out = x.permute(0, 3, 1, 2)
return out
class MixFFN(nn.Module):
def __init__(
self,
shape,
inout_channels,
hidden_channels,
mlp_bias=True,
grid="equiangular",
kernel_shape=(3, 3),
basis_type="morlet",
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
drop_path=0.0,
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm = nn.LayerNorm((inout_channels), eps=1e-05, elementwise_affine=True, bias=True)
if use_mlp:
# although the paper says MLP, it uses a single linear layer
self.mlp_in = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_in = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
# convolution for patches, cutoff radius inferred from kernel shape
theta_cutoff = _compute_cutoff_radius(shape[0], kernel_shape, basis_type)
self.conv = DiscreteContinuousConvS2(
inout_channels,
inout_channels,
in_shape=shape,
out_shape=shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid,
grid_out=grid,
groups=inout_channels,
bias=conv_bias,
theta_cutoff=theta_cutoff,
)
if use_mlp:
self.mlp_out = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_out = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.act = activation()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
# norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here; in the paper they only use a depthwise conv,
# but without this activation the 1x1 layer would just fuse into a single matrix
# multiplication with the disco conv
x = self.mlp_in(x)
# conv path
x = self.act(self.conv(x))
# second linear
x = self.mlp_out(x)
return residual + self.drop_path(x)
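# In effect, the MixFFN above computes
#   x + DropPath(mlp_out(act(depthwise_disco_conv(mlp_in(LayerNorm(x))))))
# i.e. a Segformer-style Mix-FFN in which the spatial mixing is performed by a depthwise
# DISCO convolution on the sphere rather than the planar 3x3 depthwise convolution of the
# original Segformer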
class AttentionWrapper(nn.Module):
def __init__(
self,
channels,
shape,
grid,
heads,
pre_norm=False,
attention_drop_rate=0.0,
drop_path=0.0,
attention_mode="neighborhood",
theta_cutoff=None,
bias=True
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
if theta_cutoff is None:
theta_cutoff = (7.0 / math.sqrt(math.pi)) * math.pi / (shape[0] - 1)
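# the default cutoff of (7 / sqrt(pi)) * pi / (nlat - 1) corresponds to roughly four latitude grid spacings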
self.att = NeighborhoodAttentionS2(
in_channels=channels,
in_shape=shape,
out_shape=shape,
grid_in=grid,
grid_out=grid,
theta_cutoff=theta_cutoff,
out_channels=channels,
num_heads=heads,
bias=bias
# drop_rate=attention_drop_rate,
)
else:
self.att = AttentionS2(
in_channels=channels,
num_heads=heads,
in_shape=shape,
out_shape=shape,
grid_in=grid,
grid_out=grid,
out_channels=channels,
drop_rate=attention_drop_rate,
bias=bias
)
self.norm = None
if pre_norm:
self.norm = nn.LayerNorm((channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
if self.norm is not None:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
if self.attention_mode == "neighborhood":
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.att(x).to(dtype=dtype)
else:
x = self.att(x)
return residual + self.drop_path(x)
class TransformerBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
mlp_hidden_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
heads=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.GELU,
att_drop_rate=0.0,
drop_path_rates=0.0,
attention_mode="neighborhood",
theta_cutoff=None,
bias=True
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(drop_path_rates, float):
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rates, nrep)]
assert len(drop_path_rates) == nrep
self.fwd = [
OverlapPatchMerging(
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
basis_type=basis_type,
bias=False,
)
]
for i in range(nrep):
self.fwd.append(
AttentionWrapper(
channels=out_channels,
shape=out_shape,
grid=grid_out,
heads=heads,
pre_norm=True,
attention_drop_rate=att_drop_rate,
drop_path=drop_path_rates[i],
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
bias=bias
)
)
self.fwd.append(
MixFFN(
out_shape,
inout_channels=out_channels,
hidden_channels=mlp_hidden_channels,
mlp_bias=True,
grid=grid_out,
kernel_shape=kernel_shape,
basis_type=basis_type,
conv_bias=False,
activation=activation,
use_mlp=False,
drop_path=drop_path_rates[i],
)
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fwd(x)
# apply norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
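# A TransformerBlock therefore chains an OverlapPatchMerging downsampling step, nrep repetitions
# of (AttentionWrapper, MixFFN), and a final LayerNorm, i.e. one encoder stage of the spherical Segformer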
class Upsampling(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
hidden_channels,
mlp_bias=True,
grid_in="equiangular",
grid_out="equiangular",
kernel_shape=(3, 3),
basis_type="morlet",
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
upsampling_method="conv"
):
super().__init__()
if use_mlp:
self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
if upsampling_method == "conv":
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = DiscreteContinuousConvTransposeS2(
out_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=conv_bias,
theta_cutoff=theta_cutoff,
)
elif upsampling_method == "bilinear":
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
else:
raise ValueError(f"Unknown upsampling method {upsampling_method}")
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample(self.mlp(x))
return x
class SphericalSegformer(nn.Module):
"""
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape of the input image, by default (128, 256)
kernel_shape : tuple of int, optional
Shape of the convolution kernels, by default (3, 3)
scale_factor : int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as heads
heads : List[int], optional
Number of heads for each block in the network, has to be the same length as embed_dims
depths : List[int], optional
Number of repetitions of attention blocks and FFN mixers per layer. Has to be the same length as embed_dims and heads
activation_function : str, optional
Activation function to use, by default "gelu"
embedder_kernel_shape : int, optional
Size of the encoder kernel
filter_basis_type : str, optional
Filter basis type, by default "morlet"
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
upsampling_method : str, optional
Upsampling method to use, either "conv" or "bilinear", by default "bilinear"
Example
-----------
>>> model = SphericalSegformer(
...         img_size=(128, 256),
...         scale_factor=2,
...         in_chans=2,
...         out_chans=2,
...         embed_dims=[16, 32, 64, 128],
...         heads=[1, 2, 4, 8],
...         depths=[1, 1, 1, 1],)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(3, 3),
filter_basis_type="morlet",
mlp_ratio=2.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
theta_cutoff=None,
upsampling_method="bilinear",
bias=True
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.heads = heads
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.heads) == self.num_blocks
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
self.blocks = nn.ModuleList([])
out_shape = img_size
grid_in = grid
grid_out = grid_internal
in_channels = in_chans
cur = 0
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.blocks.append(
TransformerBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
mlp_hidden_channels=int(mlp_ratio * out_channels),
grid_in=grid_in,
grid_out=grid_out,
nrep=self.depths[i],
heads=self.heads[i],
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
att_drop_rate=att_drop_rate,
drop_path_rates=dpr[cur : cur + self.depths[i]],
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
bias=bias
)
)
cur += self.depths[i]
out_shape = out_shape_new
grid_in = grid_internal
in_channels = out_channels
self.upsamplers = nn.ModuleList([])
out_shape = img_size
grid_out = grid
for i in range(self.num_blocks):
in_shape = self.blocks[i].out_shape
self.upsamplers.append(
Upsampling(
in_shape=in_shape,
out_shape=out_shape,
in_channels=self.embed_dims[i],
out_channels=self.embed_dims[i],
hidden_channels=int(mlp_ratio * self.embed_dims[i]),
mlp_bias=True,
grid_in=grid_internal,
grid_out=grid,
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
conv_bias=False,
activation=nn.GELU,
upsampling_method=upsampling_method
)
)
segmentation_head_dim = sum(self.embed_dims)
self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for block in self.blocks:
feat = block(feat)
features.append(feat)
# perform upsample
upfeats = []
for feat, upsampler in zip(features, self.upsamplers):
upfeats.append(upsampler(feat))
# perform concatenation
upfeats = torch.cat(upfeats, dim=1)
# final upsampling and prediction
out = self.segmentation_head(upfeats)
return out
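# Forward pass summary: each TransformerBlock produces features at a progressively coarser
# resolution; every feature map is upsampled back to img_size, the upsampled maps are
# concatenated along the channel dimension, and the 1x1 segmentation_head maps them to
# out_chans class logits at the input resolution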
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import NeighborhoodAttentionS2, AttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, DropPath, LayerNorm, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
):
super().__init__()
# set up local convolution
self.conv = DiscreteContinuousConvS2(
in_chans,
out_chans,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.conv(x)
x = x.to(dtype=dtype)
return x
class DiscreteContinuousDecoder(nn.Module):
def __init__(
self,
in_shape=(480, 960),
out_shape=(721, 1440),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
upsample_sht=False,
):
super().__init__()
# set up upsampling
if upsample_sht:
self.sht = RealSHT(*in_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
self.upsample = nn.Sequential(self.sht, self.isht)
else:
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution
self.conv = DiscreteContinuousConvS2(
in_chans,
out_chans,
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out,
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.upsample(x)
x = self.conv(x)
x = x.to(dtype=dtype)
return x
class SphericalAttentionBlock(nn.Module):
"""
Helper module for a single spherical attention block: (neighborhood or global) spherical self-attention followed by an MLP, with skip connections, normalization, and drop path.
"""
def __init__(
self,
in_shape=(480, 960),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
num_heads=1,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="none",
use_mlp=True,
bias=False,
attention_mode="neighborhood",
theta_cutoff=None,
):
super().__init__()
# normalisation layer
if norm_layer == "layer_norm":
self.norm0 = LayerNorm(in_channels=in_chans, eps=1e-6)
self.norm1 = LayerNorm(in_channels=out_chans, eps=1e-6)
elif norm_layer == "instance_norm":
self.norm0 = nn.InstanceNorm2d(num_features=in_chans, eps=1e-6, affine=True, track_running_stats=False)
self.norm1 = nn.InstanceNorm2d(num_features=out_chans, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm0 = nn.Identity()
self.norm1 = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# determine radius for neighborhood attention
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
if theta_cutoff is None:
theta_cutoff = (7.0 / math.sqrt(math.pi)) * math.pi / (in_shape[0] - 1)
self.self_attn = NeighborhoodAttentionS2(
in_channels=in_chans,
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
num_heads=num_heads,
theta_cutoff=theta_cutoff,
k_channels=None,
out_channels=out_chans,
bias=bias,
)
else:
self.self_attn = AttentionS2(
in_channels=in_chans,
num_heads=num_heads,
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
out_channels=out_chans,
drop_rate=drop_rate,
bias=bias,
)
self.skip0 = nn.Identity()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
if use_mlp:
mlp_hidden_dim = int(out_chans * mlp_ratio)
self.mlp = MLP(
in_features=out_chans,
out_features=out_chans,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop_rate=drop_rate,
checkpointing=False,
gain=0.5,
)
self.skip1 = nn.Identity()
def forward(self, x):
residual = x
x = self.norm0(x)
if self.attention_mode == "neighborhood":
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.self_attn(x).to(dtype=dtype)
else:
x = self.self_attn(x)
if hasattr(self, "skip0"):
x = x + self.skip0(residual)
residual = x
x = self.norm1(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, "skip1"):
x = x + self.skip1(residual)
return x
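# The SphericalAttentionBlock above follows a pre-norm transformer layout:
# x -> norm0 -> spherical self-attention (neighborhood or global) -> skip connection,
# then norm1 -> MLP -> drop path -> second skip connection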
class SphericalTransformer(nn.Module):
"""
Spherical transformer model designed to approximate mappings from spherical signals to spherical signals
Parameters
-----------
img_size : tuple, optional
Shape of the input image, by default (128, 256)
kernel_shape : tuple of int, optional
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Dimension of the embeddings, by default 256
num_layers : int, optional
Number of layers in the network, by default 4
activation_function : str, optional
Activation function to use, by default "gelu"
encoder_kernel_shape : int, optional
size of the encoder kernel
filter_basis_type: str, optional
filter basis type
num_heads: int, optional
number of attention heads
use_mlp : bool, optional
Whether to use MLPs in the attention blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional
Whether to add a single large skip connection, by default False
pos_embed : str, optional
Type of position embedding to use ("sequence", "spectral", "learnable lat", "learnable latlon", "none"), by default "spectral"
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
bias : bool, optional
Whether to use a bias, by default False
Example
-----------
>>> model = SphericalTransformer(
...         img_size=(128, 256),
... scale_factor=4,
... in_chans=2,
... out_chans=2,
... embed_dim=16,
... num_layers=4,
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="gelu",
encoder_kernel_shape=(3, 3),
filter_basis_type="morlet",
num_heads=1,
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
residual_prediction=False,
pos_embed="spectral",
upsample_sht=False,
attention_mode="neighborhood",
bias=False,
theta_cutoff=None,
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = embed_dim
self.num_layers = num_layers
self.encoder_kernel_shape = encoder_kernel_shape
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
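# e.g. img_size=(128, 256) and scale_factor=4 give (h, w) = ((128 - 1) // 4 + 1, 256 // 4) = (32, 64)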
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = DiscreteContinuousEncoder(
in_shape=self.img_size,
out_shape=(self.h, self.w),
grid_in=grid,
grid_out=grid_internal,
in_chans=self.in_chans,
out_chans=self.embed_dim,
kernel_shape=self.encoder_kernel_shape,
basis_type=filter_basis_type,
groups=1,
bias=False,
)
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
block = SphericalAttentionBlock(
in_shape=(self.h, self.w),
out_shape=(self.h, self.w),
grid_in=grid_internal,
grid_out=grid_internal,
in_chans=self.embed_dim,
out_chans=self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
bias=bias,
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
)
self.blocks.append(block)
# decoder
self.decoder = DiscreteContinuousDecoder(
in_shape=(self.h, self.w),
out_shape=self.img_size,
grid_in=grid_internal,
grid_out=grid,
in_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
basis_type=filter_basis_type,
groups=1,
bias=False,
upsample_sht=upsample_sht,
)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
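For orientation, a minimal usage sketch of the class above; the argument values are illustrative choices rather than library defaults, and the attention settings are simply passed through to the blocks:
# illustrative sketch only: instantiate the spherical transformer with explicit attention settings
model = SphericalTransformer(
    img_size=(128, 256),            # equiangular grid including both poles
    scale_factor=4,                 # internal grid becomes (32, 64)
    in_chans=2,
    out_chans=2,
    embed_dim=64,
    num_layers=4,
    num_heads=4,
    attention_mode="neighborhood",  # forwarded to the SphericalAttentionBlock
    pos_embed="spectral",
)
out = model(torch.randn(1, 2, 128, 256))  # comes back on the input grid: (1, 2, 128, 256)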
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, DropPath
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
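For intuition, a small worked example of the heuristic above (the numbers are purely illustrative):
# illustrative only: cutoff radius for a (3, 3) morlet kernel on a grid with nlat = 65
# (3 + 1) * 0.5 * pi / (65 - 1) = pi / 32 ≈ 0.098 rad ≈ 5.6 degrees
theta_cutoff = _compute_cutoff_radius(nlat=65, kernel_shape=(3, 3), basis_type="morlet")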
class DownsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.0,
drop_path_rate=0.0,
drop_dense_rate=0.0,
downsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.grid_in = grid_in
self.grid_out = grid_out
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.fwd = []
for i in range(nrep):
# conv
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.fwd.append(
DiscreteContinuousConvS2(
in_channels=(in_channels if i == 0 else out_channels),
out_channels=out_channels,
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
)
if drop_conv_rate > 0.0:
self.fwd.append(nn.Dropout2d(p=drop_conv_rate))
# batchnorm
self.fwd.append(nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# activation
self.fwd.append(
activation(),
)
if downsampling_mode == "conv":
theta_cutoff = _compute_cutoff_radius(out_shape[0], kernel_shape, basis_type)
self.downsample = DiscreteContinuousConvS2(
out_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
else:
self.downsample = ResampleS2(
nlat_in=in_shape[0],
nlon_in=in_shape[1],
nlat_out=out_shape[0],
nlon_out=out_shape[1],
grid_in=grid_in,
grid_out=grid_out,
mode=downsampling_mode,
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
if drop_dense_rate > 0.0:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = self.fwd(x)
# add residual connection
x = residual + self.drop_path(x)
# downsample
x = self.downsample(x)
return x
class UpsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.0,
drop_path_rate=0.0,
drop_dense_rate=0.0,
upsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
if in_shape != out_shape:
if upsampling_mode == "conv":
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = nn.Sequential(
DiscreteContinuousConvTransposeS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
),
nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
activation(),
DiscreteContinuousConvS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
),
)
else:
self.upsample = ResampleS2(
nlat_in=in_shape[0],
nlon_in=in_shape[1],
nlat_out=out_shape[0],
nlon_out=out_shape[1],
grid_in=grid_in,
grid_out=grid_out,
mode=upsampling_mode,
)
else:
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = DiscreteContinuousConvS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
self.fwd = []
for i in range(nrep):
# conv
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.fwd.append(
DiscreteContinuousConvS2(
in_channels=in_channels,
out_channels=(out_channels if i == nrep - 1 else in_channels),
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_in,
bias=False,
theta_cutoff=theta_cutoff,
)
)
if drop_conv_rate > 0.0:
self.fwd.append(nn.Dropout2d(p=drop_conv_rate))
# batchnorm
self.fwd.append(nn.BatchNorm2d((out_channels if i == nrep - 1 else in_channels), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# activation
self.fwd.append(
activation(),
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
if drop_dense_rate > 0.0:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = residual + self.drop_path(self.fwd(x))
# upsampling
x = self.upsample(x)
return x
class SphericalUNet(nn.Module):
"""
Spherical U-Net model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape of the input grid (nlat, nlon), by default (128, 256)
grid : str, optional
Grid type of the input and output signals, by default "equiangular"
grid_internal : str, optional
Grid type used internally by the model, by default "legendre-gauss"
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as depths
depths : List[int], optional
Number of repetitions of conv blocks and ffn mixers per layer. Has to be the same length as embed_dims
scale_factor : int, optional
Scale factor applied in each down-/upsampling block, by default 2
activation_function : str, optional
Activation function to use, by default "relu"
kernel_shape : tuple, optional
Shape of the convolution kernels, by default (3, 3)
filter_basis_type : str, optional
Filter basis type, by default "morlet"
transform_skip : bool, optional
Whether to transform skip connections with a 1x1 convolution even if channel counts match, by default False
drop_conv_rate : float, optional
Dropout rate applied after the convolutions, by default 0.1
drop_path_rate : float, optional
Dropout path rate, by default 0.1
drop_dense_rate : float, optional
Dropout rate applied to the skip transformations, by default 0.5
downsampling_mode : str, optional
Downsampling mode ("conv" or an interpolation mode such as "bilinear"), by default "bilinear"
upsampling_mode : str, optional
Upsampling mode ("conv" or an interpolation mode such as "bilinear"), by default "bilinear"
Example
-----------
>>> model = SphericalUNet(
... img_size=(128, 256),
... in_chans=2,
... out_chans=2,
... embed_dims=[16, 32, 64, 128],
... depths=[1, 1, 1, 1],)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="relu",
kernel_shape=(3, 3),
filter_basis_type="morlet",
transform_skip=False,
drop_conv_rate=0.1,
drop_path_rate=0.1,
drop_dense_rate=0.5,
downsampling_mode="bilinear",
upsampling_mode="bilinear",
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]
self.dblocks = nn.ModuleList([])
out_shape = img_size
grid_in = grid
grid_out = grid_internal
in_channels = in_chans
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.dblocks.append(
DownsamplingBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
grid_in=grid_in,
grid_out=grid_out,
nrep=self.depths[i],
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=dpr[i],
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
downsampling_mode=downsampling_mode,
)
)
out_shape = out_shape_new
grid_in = grid_internal
in_channels = out_channels
self.ublocks = nn.ModuleList([])
for i in range(self.num_blocks - 1, -1, -1):
in_shape = self.dblocks[i].out_shape
out_shape = self.dblocks[i].in_shape
in_channels = self.dblocks[i].out_channels
if i != self.num_blocks - 1:
in_channels = 2 * in_channels
out_channels = self.dblocks[i].in_channels
if i == 0:
out_channels = self.embed_dims[0]
grid_in = self.dblocks[i].grid_out
grid_out = self.dblocks[i].grid_in
self.ublocks.append(
UpsamplingBlock(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
grid_in=grid_in,
grid_out=grid_out,
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=0.0,
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
upsampling_mode=upsampling_mode,
)
)
self.head = nn.Conv2d(self.embed_dims[0], self.out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for dblock in self.dblocks:
feat = dblock(feat)
features.append(feat)
# reverse list
features = features[::-1]
# perform upsample
ufeat = self.ublocks[0](features[0])
for feat, ublock in zip(features[1:], self.ublocks[1:]):
ufeat = ublock(torch.cat([feat, ufeat], dim=1))
# last layer
out = self.head(ufeat)
return out
......@@ -28,12 +28,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
from torch_harmonics import RealSHT, InverseRealSHT
from ._layers import *
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
......@@ -53,10 +55,11 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.Identity,
norm_layer="none",
inner_skip="none",
outer_skip="identity",
use_mlp=True,
bias=False,
):
super().__init__()
......@@ -68,7 +71,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
......@@ -81,8 +84,15 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
# first normalisation layer
self.norm0 = norm_layer()
# normalisation layer
if norm_layer == "layer_norm":
self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
elif norm_layer == "instance_norm":
self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {self.norm_layer} not implemented.")
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
......@@ -108,14 +118,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
# second normalisation layer
self.norm1 = norm_layer()
def forward(self, x):
x, residual = self.global_conv(x)
x = self.norm0(x)
x = self.norm(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
......@@ -123,8 +131,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.norm1(x)
x = self.drop_path(x)
if hasattr(self, "outer_skip"):
......@@ -133,7 +139,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
return x
class SphericalFourierNeuralOperatorNet(nn.Module):
class SphericalFourierNeuralOperator(nn.Module):
"""
SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
......@@ -169,14 +175,16 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
big_skip : bool, optional
residual_prediction : bool, optional
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
bias : bool, optional
Whether to use a bias, by default False
Example:
--------
>>> model = SphericalFourierNeuralOperatorNet(
>>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
... in_chans=2,
......@@ -212,9 +220,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
pos_embed=False,
residual_prediction=False,
pos_embed="none",
bias=False,
):
super().__init__()
......@@ -231,7 +239,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
self.big_skip = big_skip
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
......@@ -252,30 +260,18 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
# pick norm layer
if self.normalization_layer == "layer_norm":
norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
elif self.normalization_layer == "instance_norm":
norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
elif self.normalization_layer == "none":
norm_layer0 = nn.Identity
norm_layer1 = norm_layer0
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed == "latlon" or pos_embed == True:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "lat":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "const":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
nn.init.constant_(self.pos_embed, 0.0)
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
self.pos_embed = None
raise ValueError(f"Unknown position embedding type {pos_embed}")
# construct an encoder with num_encoder_layers
num_encoder_layers = 1
......@@ -292,7 +288,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers.append(fc)
encoder_layers.append(self.activation_function())
current_dim = encoder_hidden_dim
fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=bias)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
......@@ -318,13 +314,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
first_layer = i == 0
last_layer = i == self.num_layers - 1
if first_layer:
norm_layer = norm_layer1
elif last_layer:
norm_layer = norm_layer0
else:
norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(
self.trans_down if first_layer else self.trans,
self.itrans_up if last_layer else self.itrans,
......@@ -334,8 +323,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
bias=bias,
)
self.blocks.append(block)
......@@ -343,7 +333,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
# construct an decoder with num_decoder_layers
num_decoder_layers = 1
decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
current_dim = self.embed_dim + self.big_skip * self.in_chans
current_dim = self.embed_dim
decoder_layers = []
for l in range(num_decoder_layers - 1):
fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
......@@ -355,7 +345,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
decoder_layers.append(fc)
decoder_layers.append(self.activation_function())
current_dim = decoder_hidden_dim
fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=bias)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
......@@ -378,19 +368,19 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def forward(self, x):
if self.big_skip:
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
if self.big_skip:
x = torch.cat((x, residual), dim=1)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
......@@ -32,7 +32,7 @@
import torch
import torch.nn as nn
import torch_harmonics as harmonics
import torch_harmonics as th
from torch_harmonics.quadrature import _precompute_longitudes
import math
......@@ -61,19 +61,19 @@ class SphereSolver(nn.Module):
self.register_buffer('coeff', torch.as_tensor(coeff, dtype=torch.float64))
# SHT
self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.lmax = lmax or self.sht.lmax
self.mmax = mmax or self.sht.mmax
# compute gridpoints
if self.grid == "legendre-gauss":
cost, _ = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
cost, _ = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
elif self.grid == "lobatto":
cost, _ = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
cost, _ = th.quadrature.lobatto_weights(self.nlat, -1, 1)
elif self.grid == "equiangular":
cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
cost, _ = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
# apply cosine transform and flip them
lats = -torch.arcsin(cost)
......
......@@ -32,7 +32,7 @@
import torch
import torch.nn as nn
import torch_harmonics as harmonics
import torch_harmonics as th
from torch_harmonics.quadrature import _precompute_longitudes
import math
......@@ -64,21 +64,21 @@ class ShallowWaterSolver(nn.Module):
self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64))
# SHT
self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.vsht = harmonics.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.ivsht = harmonics.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.vsht = th.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.ivsht = th.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.lmax = lmax or self.sht.lmax
self.mmax = mmax or self.sht.mmax
# compute gridpoints
if self.grid == "legendre-gauss":
cost, quad_weights = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
cost, quad_weights = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
elif self.grid == "lobatto":
cost, quad_weights = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
cost, quad_weights = th.quadrature.lobatto_weights(self.nlat, -1, 1)
elif self.grid == "equiangular":
cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
cost, quad_weights = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
quad_weights = quad_weights.reshape(-1, 1)
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import math
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.losses import get_quadrature_weights
# some specifiers where to find the dataset
DEFAULT_BASE_URL = "https://cvg-data.inf.ethz.ch/2d3ds/no_xyz/"
DEFAULT_TAR_FILE_PAIRS = [
("area_1_no_xyz.tar", "area_1"),
("area_2_no_xyz.tar", "area_2"),
("area_3_no_xyz.tar", "area_3"),
("area_4_no_xyz.tar", "area_4"),
("area_5a_no_xyz.tar", "area_5a"),
("area_5b_no_xyz.tar", "area_5b"),
("area_6_no_xyz.tar", "area_6"),
]
DEFAULT_LABELS_URL = "https://raw.githubusercontent.com/alexsax/2D-3D-Semantics/refs/heads/master/assets/semantic_labels.json"
class Stanford2D3DSDownloader:
"""
Convenience class for downloading the 2d3ds dataset [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
self.base_url = base_url
self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename):
import requests
from tqdm import tqdm
url = f"{self.base_url}/{filename}"
local_path = os.path.join(self.local_dir, filename)
if os.path.exists(local_path):
print(f"Note: Skipping download for {filename}, because it already exists")
return local_path
print(f"Downloading {filename}...")
temp_path = local_path.split(".")[0] + ".part"
# Resume logic
headers = {}
if os.path.exists(temp_path):
headers = {"Range": f"bytes={os.stat(temp_path).st_size}-"}
response = requests.get(url, headers=headers, stream=True, timeout=30)
if os.path.exists(temp_path):
total_size = int(response.headers.get("content-length", 0)) + os.stat(temp_path).st_size
else:
total_size = int(response.headers.get("content-length", 0))
with open(temp_path, "ab") as f, tqdm(desc=filename, total=total_size, unit="B", unit_scale=True, unit_divisor=1024, initial=os.stat(temp_path).st_size) as pbar:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
os.rename(temp_path, local_path)
return local_path
def _extract_tar(self, tar_path):
import tarfile
with tarfile.open(tar_path) as tar:
tar.extractall(path=self.local_dir)
tar_filenames = tar.getnames()
extracted_dir = tar_filenames[0]
os.remove(tar_path)
return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
import requests
data_folders = []
for file, extracted_folder_name in file_extracted_directory_pairs:
if not os.path.exists(os.path.join(self.local_dir, extracted_folder_name)):
downloaded_file = self._download_file(file)
data_folders.append(self._extract_tar(downloaded_file))
else:
print(f"Warning: Skipping D/L for '{file}' because folder '{extracted_folder_name}' already exists")
data_folders.append(extracted_folder_name)
labels_json_url = DEFAULT_LABELS_URL
class_labels = requests.get(labels_json_url).json()
return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
# Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32)
b = img[..., 2].astype(np.int32)
lookup_indices = r * 256 * 256 + g * 256 + b
def _convert(lookup: int) -> int:
# the dataset has a bad label for clutter, so we need to fix it
# clutter is 855309, but the labels file has it as 3341
# The original conversion used uint8, which overflowed the clutter label to 3341
# this is a fix to handle that accidental usage of undefined overflow behavior
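# for reference: 855309 = 13 * 256**2 + 13 * 256 + 13, i.e. RGB (13, 13, 13), and
# 855309 % 65536 = 3341, which is presumably the value the overflowing conversion produced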
if lookup == 855309:
label = class_labels_map[3341] # clutter
else:
label = class_labels_map[lookup]
class_index = class_labels_indices.index(label)
return class_index
lookup_fn = np.vectorize(_convert)
return lookup_fn(lookup_indices)
def convert_dataset(
self,
data_folders,
class_labels,
rgb_path: str = "pano/rgb",
semantic_path: str = "pano/semantic",
depth_path: str = "pano/depth",
output_filename="semantic",
dataset_file: str = "stanford_2d3ds_dataset.h5",
downsampling_factor: int = 16,
remove_alpha_channel: bool = True,
):
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image
from tqdm import tqdm
import h5py as h5
file_paths = []
min_vals = None
max_vals = None
# condition class labels first:
class_labels_map = [label.split("_")[0] for label in class_labels]
class_labels_indices = sorted(list(set(class_labels_map)))
# get all the file path input, output pairs
for base_path in data_folders:
rgb_dir = os.path.join(self.local_dir, base_path, rgb_path)
semantic_dir = os.path.join(self.local_dir, base_path, semantic_path)
depth_dir = os.path.join(self.local_dir, base_path, depth_path)
if os.path.exists(rgb_dir) and os.path.exists(semantic_dir) and os.path.exists(depth_dir):
for file_input in os.listdir(rgb_dir):
if not file_input.endswith(".png"):
continue
rgb_filepath = os.path.join(rgb_dir, file_input)
semantic_filepath = "_".join(os.path.splitext(os.path.basename(rgb_filepath))[0].split("_")[:-1]) + f"_{output_filename}.png"
semantic_filepath = os.path.join(semantic_dir, semantic_filepath)
depth_filepath = "_".join(os.path.splitext(os.path.basename(rgb_filepath))[0].split("_")[:-1]) + f"_depth.png"
depth_filepath = os.path.join(depth_dir, depth_filepath)
if not os.path.exists(semantic_filepath):
print(f"Warning: Couldn't find output file in pair: ({rgb_filepath},{semantic_filepath})")
continue
if not os.path.exists(depth_filepath):
print(f"Warning: Couldn't find depth file in pair: ({rgb_filepath},{depth_filepath})")
continue
file_paths.append((rgb_filepath, semantic_filepath, depth_filepath))
elif not os.path.exists(rgb_dir):
print("Warning: RGB dir doesn't exist: ", rgb_dir)
continue
elif not os.path.exists(semantic_dir):
print("Warning: Semantic dir doesn't exist: ", semantic_dir)
continue
elif not os.path.exists(depth_dir):
print("Warning: Depth dir doesn't exist: ", depth_dir)
continue
num_samples = len(file_paths)
if num_samples > 0:
first_rgb, first_semantic, first_depth = file_paths[0]
first_rgb = np.array(Image.open(first_rgb))
# first_semantic = np.array(Image.open(first_semantic))
# first_depth = np.array(Image.open(first_depth))
rgb_shape = first_rgb.shape
img_shape = (rgb_shape[0] // downsampling_factor, rgb_shape[1] // downsampling_factor)
rgb_channels = rgb_shape[2]
if remove_alpha_channel:
rgb_channels = 3
else:
raise ValueError(f"No samples found")
# create the dataset file
with h5.File(converted_dataset_path, "w") as h5file:
rgb_data = h5file.create_dataset("rgb", (num_samples, rgb_channels, *img_shape), "f4")
semantic_data = h5file.create_dataset("semantic", (num_samples, *img_shape), "i8")
depth_data = h5file.create_dataset("depth", (num_samples, *img_shape), "f4")
classes = h5file.create_dataset("class_labels", data=class_labels_indices)
num_classes = len(set(class_labels_indices))
data_source_path = h5file.create_dataset("data_source_path", (num_samples,), dtype=h5.string_dtype(encoding="utf-8"))
data_target_path = h5file.create_dataset("data_target_path", (num_samples,), dtype=h5.string_dtype(encoding="utf-8"))
# prepare computation of the class histogram
class_histogram = np.zeros(num_classes)
_, quad_weights = _precompute_latitudes(nlat=img_shape[0], grid="equiangular")
quad_weights = quad_weights.reshape(-1, 1) * 2 * torch.pi / float(img_shape[1])
quad_weights = quad_weights.tile(1, img_shape[1])
quad_weights /= torch.sum(quad_weights)
quad_weights = quad_weights.numpy()
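# at this point quad_weights[i, j] approximates the relative area of grid cell (i, j) on the sphere,
# normalized so that the weights sum to one over the full grid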
for count in tqdm(range(num_samples), desc="preparing dataset"):
# open image
img = Image.open(file_paths[count][0])
# downsampling
if downsampling_factor != 1:
# PIL expects the target size as (width, height)
img = img.resize(size=(img_shape[1], img_shape[0]), resample=Image.BILINEAR)
# remove alpha channel if requested
if remove_alpha_channel:
img = img.convert("RGBA")
background = Image.new("RGBA", img.size, (255, 255, 255))
# composite foreground and background and remove the alpha channel
img = np.array(Image.alpha_composite(background, img))
r_data = img[:, :, :3]
else:
r_data = np.array(img)
# transpose to channels first
r_data = np.transpose(r_data / 255.0, axes=(2, 0, 1))
# write to disk
rgb_data[count, ...] = r_data[...]
data_source_path[count] = file_paths[count][0]
# compute stats -> segmentation
# min/max
tmp_min = np.min(r_data, axis=(1, 2))
tmp_max = np.max(r_data, axis=(1, 2))
# mean/var
tmp_mean = np.sum(r_data * quad_weights[np.newaxis, :, :], axis=(1, 2))
tmp_m2 = np.sum(np.square(r_data - tmp_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(1, 2))
if count == 0:
# min/max
min_vals = tmp_min
max_vals = tmp_max
# mean/var
mean_vals = tmp_mean
m2_vals = tmp_m2
else:
# min/max
min_vals = np.minimum(min_vals, tmp_min)
max_vals = np.maximum(max_vals, tmp_max)
# mean/var
delta = tmp_mean - mean_vals
mean_vals += delta / float(count + 1)
m2_vals += tmp_m2 + delta * delta * float(count / (count + 1))
# get the target
sem = Image.open(file_paths[count][1])
# downsampling
if downsampling_factor != 1:
sem = sem.resize(size=(img_shape[1], img_shape[0]), resample=Image.NEAREST)
sem_data = np.array(sem, dtype=np.uint32)
# map to classes
sem_data = self._rgb_to_id(sem_data, class_labels_map, class_labels_indices)
# write to file
semantic_data[count, ...] = sem_data[...]
data_target_path[count] = file_paths[count][1]
# Here we want depth
dep = Image.open(file_paths[count][2])
if downsampling_factor != 1:
dep = dep.resize(size=(img_shape[1], img_shape[0]), resample=Image.NEAREST)
dep_data = np.array(dep)
depth_data[count, ...] = dep_data[...] / 65536.0
# compute stats -> depth
# min/max
tmp_min_depth = np.min(dep_data, axis=(0, 1))
tmp_max_depth = np.max(dep_data, axis=(0, 1))
# mean/var
tmp_mean_depth = np.sum(dep_data * quad_weights[:, :])
tmp_m2_depth = np.sum(np.square(dep_data - tmp_mean_depth) * quad_weights[:, :])
if count == 0:
min_vals_depth = tmp_min_depth
max_vals_depth = tmp_max_depth
mean_vals_depth = tmp_mean_depth
m2_vals_depth = tmp_m2_depth
else:
min_vals_depth = np.minimum(min_vals_depth, tmp_min_depth)
max_vals_depth = np.maximum(max_vals_depth, tmp_max_depth)
delta = tmp_mean_depth - mean_vals_depth
mean_vals_depth += delta / float(count + 1)
m2_vals_depth += tmp_m2_depth + delta * delta * float(count / (count + 1))
# update the class histogram
for c in range(num_classes):
class_histogram[c] += quad_weights[sem_data == c].sum()
# record min/max
h5file.create_dataset("min_rgb", data=min_vals.astype(np.float32))
h5file.create_dataset("max_rgb", data=max_vals.astype(np.float32))
h5file.create_dataset("mean_rgb", data=mean_vals.astype(np.float32))
std_vals = np.sqrt(m2_vals / float(num_samples - 1))
h5file.create_dataset("std_rgb", data=std_vals.astype(np.float32))
# record min/max
h5file.create_dataset("min_depth", data=min_vals_depth.astype(np.float32))
h5file.create_dataset("max_depth", data=max_vals_depth.astype(np.float32))
h5file.create_dataset("mean_depth", data=mean_vals_depth.astype(np.float32))
std_vals_depth = np.sqrt(m2_vals_depth / float(num_samples - 1))
h5file.create_dataset("std_depth", data=std_vals_depth.astype(np.float32))
# record class histogram
class_histogram = class_histogram / num_samples
h5file.create_dataset("class_histogram", data=class_histogram.astype(np.float32))
return converted_dataset_path
def prepare_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS, dataset_file: str = "stanford_2d3ds_dataset.h5", downsampling_factor: int = 16):
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
if os.path.exists(converted_dataset_path):
print(
f"Dataset file at {converted_dataset_path} already exists. Skipping download and conversion. If you want to create a new dataset file, delete or rename the existing file."
)
return converted_dataset_path
data_folders, class_labels = self.download_dataset(file_extracted_directory_pairs=file_extracted_directory_pairs)
converted_dataset_path = self.convert_dataset(data_folders=data_folders, class_labels=class_labels, dataset_file=dataset_file, downsampling_factor=downsampling_factor)
self.converted_dataset_path = converted_dataset_path
return self.converted_dataset_path
class StanfordSegmentationDataset(Dataset):
"""
Spherical segmentation dataset from [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(
self,
dataset_file,
ignore_alpha_channel=True,
exclude_polar_fraction=0,
):
import h5py as h5
self.dataset_file = dataset_file
self.exclude_polar_fraction = exclude_polar_fraction
with h5.File(self.dataset_file, "r") as h5file:
self.img_rgb = h5file["rgb"][0].shape
self.img_seg = h5file["semantic"][0].shape
self.num_samples = h5file["rgb"].shape[0]
self.num_classes = h5file["class_labels"].shape[0]
self.class_labels = [class_name.decode("utf-8") for class_name in h5file["class_labels"][...].tolist()]
self.class_histogram = np.array(h5file["class_histogram"][...])
self.class_histogram = self.class_histogram / self.class_histogram.sum()
self.mean = h5file["mean_rgb"][...]
self.std = h5file["std_rgb"][...]
self.min = h5file["min_rgb"][...]
self.max = h5file["max_rgb"][...]
self.img_filepath = h5file["data_source_path"][...]
self.tar_filepath = h5file["data_target_path"][...]
if ignore_alpha_channel:
self.img_rgb = (3, self.img_rgb[1], self.img_rgb[2])
# the file handle and datasets are opened lazily in _init_files
self.h5file = None
self.rgb = None
self.semantic = None
# return index set to false by default
# when true, the __getitem__ method will return the index of the input,target pair
self.return_index = False
@property
def target_shape(self):
return self.img_seg
@property
def input_shape(self):
return self.img_rgb
def set_return_index(self, return_index: bool):
self.return_index = return_index
def get_img_filepath(self, idx: int):
return self.img_filepath[idx]
def get_tar_filepath(self, idx: int):
return self.tar_filepath[idx]
def _id_to_class(self, class_id):
if class_id >= self.num_classes:
print("WARNING: class id exceeds the number of classes!")
return None
return self.class_labels[class_id]
def _mask_invalid(self, tar):
return np.where(tar >= self.num_classes, -100, tar)
def __len__(self):
return self.num_samples
def _init_files(self):
import h5py as h5
self.h5file = h5.File(self.dataset_file, "r")
self.rgb = self.h5file["rgb"]
self.semantic = self.h5file["semantic"]
def reset(self):
self.rgb = None
self.semantic = None
if self.h5file is not None:
self.h5file.close()
del self.h5file
self.h5file = None
def __getitem__(self, idx, mask_invalid=True):
if self.h5file is None:
# init files
self._init_files()
rgb = self.rgb[idx, 0 : self.img_rgb[0], 0 : self.img_rgb[1], 0 : self.img_rgb[2]]
sem = self.semantic[idx, 0 : self.img_seg[0], 0 : self.img_seg[1]]
if mask_invalid:
sem = self._mask_invalid(sem)
if self.exclude_polar_fraction > 0:
hcut = int(self.exclude_polar_fraction * sem.shape[0])
if hcut > 0:
sem[0:hcut, :] = -100
sem[-hcut:, :] = -100
return rgb, sem
class StanfordDatasetSubset(Subset):
def __init__(self, dataset, indices, return_index=False):
super().__init__(dataset, indices)
self.return_index = return_index
self.dataset = dataset
def set_return_index(self, value):
self.return_index = value
def __getitem__(self, index):
real_index = self.indices[index]
data = self.dataset[real_index]
if self.return_index:
return data[0], data[1], real_index
else:
# Otherwise, return only (data, target)
return data[0], data[1]
class StanfordDepthDataset(Dataset):
"""
Spherical depth estimation dataset from [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(self, dataset_file, ignore_alpha_channel=True, log_depth=False, exclude_polar_fraction=0.0):
import h5py as h5
self.dataset_file = dataset_file
self.log_depth = log_depth
self.exclude_polar_fraction = exclude_polar_fraction
with h5.File(self.dataset_file, "r") as h5file:
self.img_rgb = h5file["rgb"][0].shape
self.img_depth = h5file["depth"][0].shape
self.num_samples = h5file["rgb"].shape[0]
self.mean_in = h5file["mean_rgb"][...]
self.std_in = h5file["std_rgb"][...]
self.min_in = h5file["min_rgb"][...]
self.max_in = h5file["max_rgb"][...]
self.mean_out = h5file["mean_depth"][...]
self.std_out = h5file["std_depth"][...]
self.min_out = h5file["min_depth"][...]
self.max_out = h5file["max_depth"][...]
if ignore_alpha_channel:
self.img_rgb = (3, self.img_rgb[1], self.img_rgb[2])
# the file handle and datasets are opened lazily in _init_files
self.h5file = None
self.rgb = None
self.depth = None
@property
def target_shape(self):
return self.img_depth
@property
def input_shape(self):
return self.img_rgb
def __len__(self):
return self.num_samples
def _init_files(self):
import h5py as h5
self.h5file = h5.File(self.dataset_file, "r")
self.rgb = self.h5file["rgb"]
self.depth = self.h5file["depth"]
def reset(self):
self.rgb = None
self.depth = None
if self.h5file is not None:
self.h5file.close()
del self.h5file
self.h5file = None
def _mask_invalid(self, tar):
return tar * np.where(tar == tar.max(), 0, 1)
def __getitem__(self, idx, mask_invalid=True):
if self.h5file is None:
# init files
self._init_files()
rgb = self.rgb[idx, 0 : self.img_rgb[0], 0 : self.img_rgb[1], 0 : self.img_rgb[2]]
depth = self.depth[idx, 0 : self.img_depth[0], 0 : self.img_depth[1]]
if mask_invalid:
depth = self._mask_invalid(depth)
if self.exclude_polar_fraction > 0:
hcut = int(self.exclude_polar_fraction * depth.shape[0])
if hcut > 0:
depth[0:hcut, :] = 0
depth[-hcut:, :] = 0
if self.log_depth:
depth = np.log(1 + depth)
return rgb, depth
def compute_stats_s2(dataset: Dataset, normalize_target: bool = False):
"""
Compute statistics using a parallel Welford reduction and quadrature on the sphere. The reduction follows the parallel algorithm described at https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
"""
nexamples = len(dataset)
count = 0
for isample in range(nexamples):
token = dataset[isample]
# dimension of inp and tar are (3, nlat, nlon)
inp, tar = token
nlat = tar.shape[-2]
nlon = tar.shape[-1]
# pre-compute quadrature weights
if isample == 0:
quad_weights = get_quadrature_weights(nlat=inp.shape[1], nlon=inp.shape[2], grid="equiangular", tile=True).numpy().astype(np.float64)
# this is a special case for the depth dataset
# TODO: maybe make this an argument
if normalize_target:
mask = np.where(tar == 0, 0, 1)
masked_area = np.sum(mask * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# get initial welford values
if isample == 0:
# input
inp_means = np.sum(inp * quad_weights[np.newaxis, :, :], axis=(-2, -1))
inp_m2s = np.sum(np.square(inp - inp_means[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# target
if normalize_target:
tar_means = np.sum(mask * tar * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
tar_m2s = np.sum(mask * np.square(tar - tar_means[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
# update count
count = 1
# do welford update
else:
# input
# get new mean and m2
inp_mean = np.sum(inp * quad_weights[np.newaxis, :, :], axis=(-2, -1))
inp_m2 = np.sum(np.square(inp - inp_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# update welford values
inp_delta = inp_mean - inp_means
inp_m2s = inp_m2s + inp_m2 + inp_delta**2 * count / float(count + 1)
inp_means = inp_means + inp_delta / float(count + 1)
# target
if normalize_target:
# get new mean and m2
tar_mean = np.sum(mask * tar * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
tar_m2 = np.sum(mask * np.square(tar - tar_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
# update welford values
tar_delta = tar_mean - tar_means
tar_m2s = tar_m2s + tar_m2 + tar_delta**2 * count / float(count + 1)
tar_means = tar_means + tar_delta / float(count + 1)
# update count
count += 1
# finalize
inp_stds = np.sqrt(inp_m2s / float(count))
result = (inp_means.astype(np.float32), inp_stds.astype(np.float32))
if normalize_target:
tar_stds = np.sqrt(tar_m2s / float(count))
result += (tar_means.astype(np.float32), tar_stds.astype(np.float32))
return result
......@@ -30,70 +30,158 @@
#
import numpy as np
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import os
# guarded imports
try:
import matplotlib.pyplot as plt
except ImportError as err:
plt = None
try:
import cartopy
import cartopy.crs as ccrs
except ImportError as err:
cartopy = None
ccrs = None
def check_plotting_dependencies():
if plt is None:
raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
if cartopy is None:
raise ImportError("cartopy is required for map plotting. Install it with 'pip install cartopy'")
def get_projection(
projection,
central_latitude=0,
central_longitude=0,
):
if projection == "orthographic":
proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
elif projection == "robinson":
proj = ccrs.Robinson(central_longitude=central_longitude)
elif projection == "platecarree":
proj = ccrs.PlateCarree(central_longitude=central_longitude)
elif projection == "mollweide":
proj = ccrs.Mollweide(central_longitude=central_longitude)
else:
raise ValueError(f"Unknown projection mode {projection}")
return proj
def plot_sphere(
data, fig=None, projection="robinson", cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs
):
"""
Plots a function defined on the sphere using pcolormesh
"""
# make sure cartopy exist
check_plotting_dependencies()
def plot_sphere(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs):
if fig == None:
fig = plt.figure()
nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2 * np.pi, nlon)
lon = np.linspace(0, 2 * np.pi, nlon + 1)[:-1]
if lat is None:
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj)
# convert radians to degrees
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi
# get the requested projection
proj = get_projection(projection, central_latitude=central_latitude, central_longitude=central_longitude)
ax = fig.add_subplot(projection=proj)
# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
# add features if requested
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
# add colorbar if requested
if colorbar:
plt.colorbar(im, extend="both")
plt.title(title, y=1.05)
plt.colorbar(im)
# add gridlines
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1, color="gray", alpha=0.6, linestyle="--")
# add title with smaller font
plt.title(title, y=1.05, fontsize=8)
return im
def plot_data(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_longitude=0, lon=None, lat=None, **kwargs):
def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
"""
Displays an image on the sphere
"""
# make sure cartopy exist
check_plotting_dependencies()
if fig == None:
fig = plt.figure()
nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None:
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)
# get the projection; the central longitude is shifted by 180 degrees so the orientation matches plot_sphere
proj = get_projection(projection, central_latitude=central_latitude, central_longitude=central_longitude + 180)
ax = fig.add_subplot(projection=proj)
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi
# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar:
plt.colorbar(im, extend="both")
im = ax.imshow(data, transform=ccrs.PlateCarree(), **kwargs)
# add title
plt.title(title, y=1.05)
return im
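A minimal usage sketch of the plotting helper above; the random data and figure setup are illustrative assumptions, not part of the library:
# illustrative usage only
fig = plt.figure(figsize=(6, 3))
data = np.random.randn(91, 180)  # (nlat, nlon) field on an equiangular grid
plot_sphere(data, fig=fig, projection="mollweide", title="example field", colorbar=True, coastlines=True)
plt.show()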
# def plot_data(data,
# fig=None,
# cmap="RdBu",
# title=None,
# colorbar=False,
# coastlines=False,
# central_longitude=0,
# lon=None,
# lat=None,
# **kwargs):
# if fig == None:
# fig = plt.figure()
# nlat = data.shape[-2]
# nlon = data.shape[-1]
# if lon is None:
# lon = np.linspace(0, 2*np.pi, nlon+1)[:-1]
# if lat is None:
# lat = np.linspace(np.pi/2., -np.pi/2., nlat)
# Lon, Lat = np.meshgrid(lon, lat)
# proj = ccrs.Robinson(central_longitude=central_longitude)
# # proj = ccrs.Mollweide(central_longitude=central_longitude)
# ax = fig.add_subplot(projection=proj)
# Lon = Lon*180/np.pi
# Lat = Lat*180/np.pi
# # contour data over the map.
# im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
# if coastlines:
# ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
# if colorbar:
# plt.colorbar(im)
# plt.title(title, y=1.05)
# return im
......@@ -121,14 +121,15 @@ class ResampleS2(nn.Module):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation
# do the interpolation in precision of x
lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x
......@@ -142,14 +143,15 @@ class ResampleS2(nn.Module):
return x
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
# do the interpolation in precision of x
lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x
......