Commit 6a845fd3 authored by Boris Bonev, committed by Boris Bonev

adding spherical attention

parent b3816ebc
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes
from .losses import get_quadrature_weights
# routine to compute multiclass labels on the sphere
# the routine follows the implementation in
# https://github.com/qubvel-org/segmentation_models.pytorch/blob/4aa36c6ad13f8a12552e4ea4131af2a86e564962/segmentation_models_pytorch/metrics/functional.py
# but uses quadrature weights
def _get_stats_multiclass(
output: torch.LongTensor,
target: torch.LongTensor,
num_classes: int,
quad_weights: torch.Tensor,
ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
if ignore_index is not None:
ignore = target == ignore_index
output = torch.where(ignore, -1, output)
target = torch.where(ignore, -1, target)
ignore_per_sample = ignore.view(batch_size, -1).sum(1)
tp_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
fp_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
fn_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
tn_count = torch.zeros(batch_size, num_classes, dtype=torch.float32, device=output.device)
matched = target == output
not_matched = target != output
for i in range(batch_size):
matched_i = matched[i, ...]
not_matched_i = not_matched[i, ...]
target_i = target[i, ...]
output_i = output[i, ...]
for c in range(num_classes):
# compute weights
qwt_c = quad_weights[target_i == c]
qwo_c = quad_weights[output_i == c]
# true positives
tp_count[i, c] = torch.sum(matched_i[target_i == c] * qwt_c)
# false positives
fp_count[i, c] = torch.sum(not_matched_i[output_i == c] * qwo_c)
# false negatives
fn_count[i, c] = torch.sum(not_matched_i[target_i == c] * qwt_c)
# true negatives is the leftovers
tn_count = torch.sum(quad_weights) - tp_count - fp_count - fn_count
return tp_count, fp_count, fn_count, tn_count
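# note: argmax is invariant under softmax, so the softmax in _predict_classes below only
# affects readability, not the predicted classes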
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)
class BaseMetricS2(nn.Module):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
self.ignore_index = ignore_index
self.mode = mode
# area weights
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid, tile=True)
self.register_buffer("quad_weights", q)
if weight is None:
self.weight = None
else:
self.register_buffer("weight", weight.unsqueeze(0))
def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# convert logits to class predictions
pred_class = _predict_classes(pred)
# get true positive, false positive, etc
tp, fp, fn, tn = _get_stats_multiclass(pred_class, truth, pred.shape[1], self.quad_weights, self.ignore_index)
# compute averages:
if self.mode == "micro":
if self.weight is not None:
# weighted average
tp = torch.sum(tp * self.weight)
fp = torch.sum(fp * self.weight)
fn = torch.sum(fn * self.weight)
tn = torch.sum(tn * self.weight)
else:
# normal average
tp = torch.mean(tp)
fp = torch.mean(fp)
fn = torch.mean(fn)
tn = torch.mean(tn)
else:
tp = torch.mean(tp, dim=0)
fp = torch.mean(fp, dim=0)
fn = torch.mean(fn, dim=0)
tn = torch.mean(tn, dim=0)
return tp, fp, fn, tn
class IntersectionOverUnionS2(BaseMetricS2):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
score = tp / (tp + fp + fn)
if self.mode == "macro":
# we need to do some averaging still:
# be careful with zeros
score = torch.where(torch.isnan(score), 0.0, score)
if self.weight is not None:
score = torch.sum(score * self.weight)
else:
score = torch.mean(score)
return score
class AccuracyS2(BaseMetricS2):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
score = (tp + tn) / (tp + fp + fn + tn)
if self.mode == "macro":
# we need to do some averaging still:
# be careful with zeros
score = torch.where(torch.isnan(score), 0.0, score)
if self.weight is not None:
score = torch.sum(score * self.weight)
else:
score = torch.mean(score)
return score
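# minimal usage sketch (illustration only), assuming logits of shape (batch, num_classes, nlat, nlon)
# and integer class labels of shape (batch, nlat, nlon)
if __name__ == "__main__":
    nlat, nlon, num_classes = 32, 64, 4
    logits = torch.randn(2, num_classes, nlat, nlon)
    labels = torch.randint(0, num_classes, (2, nlat, nlon))
    iou = IntersectionOverUnionS2(nlat, nlon, grid="equiangular", mode="macro")
    acc = AccuracyS2(nlat, nlon, grid="equiangular", mode="micro")
    print(f"IoU: {iou(logits, labels).item():.4f}, accuracy: {acc(logits, labels).item():.4f}")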
@@ -29,5 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from .sfno import SphericalFourierNeuralOperator
from .lsno import LocalSphericalNeuralOperator
from .s2unet import SphericalUNet
from .s2transformer import SphericalTransformer
from .s2segformer import SphericalSegformer
@@ -29,26 +29,26 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
import math
import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint
from torch_harmonics import InverseRealSHT
from ._activations import *
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
@@ -66,7 +66,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
@@ -74,7 +74,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
@@ -95,7 +95,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
@@ -103,9 +103,9 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) ->
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1.0 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
@@ -114,8 +114,8 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) ->
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
@@ -123,16 +123,30 @@ class DropPath(nn.Module):
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super(PatchEmbed, self).__init__()
self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
num_patches = self.red_img_size[0] * self.red_img_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
self.proj.weight.is_shared_mp = ["spatial"]
self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x):
# gather input
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# new: B, C, H*W
x = self.proj(x).flatten(2)
return x
class MLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
super(MLP, self).__init__()
self.checkpointing = checkpointing
out_features = out_features or in_features
@@ -142,7 +156,7 @@ class MLP(nn.Module):
fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
# initialize the weights correctly
scale = math.sqrt(2.0 / in_features)
nn.init.normal_(fc1.weight, mean=0.0, std=scale)
if fc1.bias is not None:
nn.init.constant_(fc1.bias, 0.0)
@@ -153,11 +167,11 @@ class MLP(nn.Module):
fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
# gain factor for the output determines the scaling of the output init
scale = math.sqrt(gain / hidden_features)
nn.init.normal_(fc2.weight, mean=0.0, std=scale)
if fc2.bias is not None:
nn.init.constant_(fc2.bias, 0.0)
if drop_rate > 0.0:
drop = nn.Dropout2d(drop_rate)
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
@@ -173,15 +187,13 @@ class MLP(nn.Module):
else:
return self.fwd(x)
class RealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(RealFFT2, self).__init__()
self.nlat = nlat
@@ -191,18 +203,16 @@ class RealFFT2(nn.Module):
def forward(self, x):
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
return y
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(InverseRealFFT2, self).__init__()
self.nlat = nlat
@@ -213,6 +223,24 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module):
"""
Wrapper class that moves the channel dimension to the end
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
super().__init__()
self.channel_dim = -3
self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
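# usage sketch (assuming a channels-first input of shape (batch, channels, nlat, nlon)):
# norm = LayerNorm(in_channels=64)
# y = norm(torch.randn(2, 64, 32, 64))  # layer norm is applied over the channel dimension only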
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
@@ -220,15 +248,7 @@ class SpectralConvS2(nn.Module):
domain via the RealFFT2 and InverseRealFFT2 wrappers.
"""
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super().__init__()
self.forward_transform = forward_transform
@@ -237,8 +257,7 @@ class SpectralConvS2(nn.Module):
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type
@@ -266,7 +285,6 @@ class SpectralConvS2(nn.Module):
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
@@ -287,4 +305,117 @@ class SpectralConvS2(nn.Module):
x = x + self.bias
x = x.type(dtype)
return x, residual
\ No newline at end of file
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for position embeddings. The forward pass adds the precomputed or learned position embedding to the input.
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__()
self.img_shape = img_shape
self.num_chans = num_chans
def forward(self, x: torch.Tensor):
return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding):
"""
Returns standard sequence based position embedding
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
with torch.no_grad():
# alternating custom position embeddings
pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
denom = torch.pow(10000, 2 * k / self.num_chans)
pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))
# register tensor
self.register_buffer("position_embeddings", pos_embed.float())
class SpectralPositionEmbedding(PositionEmbedding):
"""
Returns position embeddings for the spherical transformer
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# compute maximum required frequency and prepare isht
lmax = math.floor(math.sqrt(self.num_chans)) + 1
isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
# fill position embedding
with torch.no_grad():
pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
for i in range(self.num_chans):
l = math.floor(math.sqrt(i))
m = i - l**2 - l
if m < 0:
pos_embed_freq[0, i, l, -m] = 1.0j
else:
pos_embed_freq[0, i, l, m] = 1.0
# compute spatial position embeddings
pos_embed = isht(pos_embed_freq)
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
# register tensor
self.register_buffer("position_embeddings", pos_embed)
class LearnablePositionEmbedding(PositionEmbedding):
"""
Returns position embeddings for the spherical transformer
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
if embed_type == "latlon":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
else:
raise ValueError(f"Unknown learnable position embedding type {embed_type}")
# class SpiralPositionEmbedding(PositionEmbedding):
# """
# Returns position embeddings on the torus
# """
# def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
# super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# with torch.no_grad():
# # alternating custom position embeddings
# lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
# lats = lats.reshape(-1, 1)
# lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
# lons = lons.reshape(1, -1)
# # channel index
# k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -29,6 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
@@ -37,10 +39,15 @@ from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import ResampleS2
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
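# worked example (for illustration): with nlat = 481, kernel_shape = (3, 3) and the "morlet" basis,
# theta_cutoff = (3 + 1) * 0.5 * pi / 480 = pi / 240 ~ 0.013 rad, i.e. the kernel support spans
# roughly two latitude spacings of pi / 480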
class DiscreteContinuousEncoder(nn.Module):
def __init__(
@@ -51,8 +58,8 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
):
@@ -70,7 +77,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
@@ -93,11 +100,11 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
upsample_sht=False,
):
super().__init__()
@@ -121,7 +128,7 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
@@ -152,12 +159,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="none",
inner_skip="none",
outer_skip="identity",
use_mlp=True,
disco_kernel_shape=(3, 3),
disco_basis_type="morlet",
bias=False,
):
super().__init__()
@@ -171,6 +179,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
# convolution layer
if conv_type == "local":
theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
self.local_conv = DiscreteContinuousConvS2(
input_dim,
output_dim,
@@ -180,11 +189,11 @@ class SphericalNeuralOperatorBlock(nn.Module):
basis_type=disco_basis_type,
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=bias,
theta_cutoff=theta_cutoff,
)
elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
else:
raise ValueError(f"Unknown convolution type {conv_type}")
@@ -199,8 +208,15 @@ class SphericalNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
# normalisation layer
if norm_layer == "layer_norm":
self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
elif norm_layer == "instance_norm":
self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
@@ -232,9 +248,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x):
residual = x
@@ -244,7 +257,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
elif hasattr(self, "local_conv"):
x = self.local_conv(x)
x = self.norm(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
@@ -252,8 +265,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, "outer_skip"):
@@ -262,7 +273,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
return x
class LocalSphericalNeuralOperator(nn.Module):
"""
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
@@ -300,6 +311,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
sfno_block_frequency : int, optional
How often a (global) SFNO block is used, by default 2
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
big_skip : bool, optional
@@ -308,6 +321,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to use positional embedding, by default True
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
bias : bool, optional
Whether to use a bias, by default False
Example
-----------
@@ -345,19 +360,20 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
embed_dim=256,
num_layers=4,
activation_function="gelu",
kernel_shape=(3, 3),
encoder_kernel_shape=(3, 3),
filter_basis_type="morlet",
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
sfno_block_frequency=2,
hard_thresholding_fraction=1.0,
residual_prediction=False,
pos_embed="none",
upsample_sht=False,
bias=False,
):
super().__init__()
@@ -373,7 +389,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
@@ -394,30 +410,18 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = DiscreteContinuousEncoder(
@@ -445,30 +449,22 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
block = SphericalNeuralOperatorBlock(
self.trans,
self.itrans,
self.embed_dim,
self.embed_dim,
conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency-1) else "local",
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
disco_kernel_shape=kernel_shape,
disco_basis_type=filter_basis_type,
bias=bias,
)
self.blocks.append(block)
@@ -485,17 +481,9 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
basis_type=filter_basis_type,
groups=1,
bias=False,
upsample_sht=upsample_sht,
)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
@@ -509,20 +497,19 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
return x
def forward(self, x):
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class OverlapPatchMerging(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(481, 960),
grid_in="equiangular",
grid_out="equiangular",
in_channels=3,
out_channels=64,
kernel_shape=(3, 3),
basis_type="morlet",
bias=False,
):
super().__init__()
# convolution for patches, cutoff radius inferred from kernel shape
theta_cutoff = _compute_cutoff_radius(out_shape[0], kernel_shape, basis_type)
self.conv = DiscreteContinuousConvS2(
in_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=bias,
theta_cutoff=theta_cutoff,
)
# layer norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.conv(x).to(dtype=dtype)
# permute
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
out = x.permute(0, 3, 1, 2)
return out
class MixFFN(nn.Module):
def __init__(
self,
shape,
inout_channels,
hidden_channels,
mlp_bias=True,
grid="equiangular",
kernel_shape=(3, 3),
basis_type="morlet",
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
drop_path=0.0,
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm = nn.LayerNorm((inout_channels), eps=1e-05, elementwise_affine=True, bias=True)
if use_mlp:
# although the paper says MLP, it uses a single linear layer
self.mlp_in = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_in = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
# convolution for patches, cutoff radius inferred from kernel shape
theta_cutoff = _compute_cutoff_radius(shape[0], kernel_shape, basis_type)
self.conv = DiscreteContinuousConvS2(
inout_channels,
inout_channels,
in_shape=shape,
out_shape=shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid,
grid_out=grid,
groups=inout_channels,
bias=conv_bias,
theta_cutoff=theta_cutoff,
)
if use_mlp:
self.mlp_out = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_out = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.act = activation()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
# norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here, because the paper
# only uses a depthwise conv; without this activation it
# would just be a fused MM with the disco conv
x = self.mlp_in(x)
# conv part
x = self.act(self.conv(x))
# second linear
x = self.mlp_out(x)
return residual + self.drop_path(x)
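# data flow of MixFFN (summary of the forward pass above):
# x -> LayerNorm -> mlp_in (1x1 conv or MLP) -> depthwise DISCO conv -> activation
# -> mlp_out (1x1 conv or MLP) -> residual add with drop path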
class AttentionWrapper(nn.Module):
def __init__(
self,
channels,
shape,
grid,
heads,
pre_norm=False,
attention_drop_rate=0.0,
drop_path=0.0,
attention_mode="neighborhood",
theta_cutoff=None,
bias=True
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
if theta_cutoff is None:
theta_cutoff = (7.0 / math.sqrt(math.pi)) * math.pi / (shape[0] - 1)
self.att = NeighborhoodAttentionS2(
in_channels=channels,
in_shape=shape,
out_shape=shape,
grid_in=grid,
grid_out=grid,
theta_cutoff=theta_cutoff,
out_channels=channels,
num_heads=heads,
bias=bias
# drop_rate=attention_drop_rate,
)
else:
self.att = AttentionS2(
in_channels=channels,
num_heads=heads,
in_shape=shape,
out_shape=shape,
grid_in=grid,
grid_out=grid,
out_channels=channels,
drop_rate=attention_drop_rate,
bias=bias
)
self.norm = None
if pre_norm:
self.norm = nn.LayerNorm((channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
if self.norm is not None:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
if self.attention_mode == "neighborhood":
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.att(x).to(dtype=dtype)
else:
x = self.att(x)
return residual + self.drop_path(x)
class TransformerBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
mlp_hidden_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
heads=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.GELU,
att_drop_rate=0.0,
drop_path_rates=0.0,
attention_mode="neighborhood",
theta_cutoff=None,
bias=True
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(drop_path_rates, float):
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rates, nrep)]
assert len(drop_path_rates) == nrep
self.fwd = [
OverlapPatchMerging(
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
basis_type=basis_type,
bias=False,
)
]
for i in range(nrep):
self.fwd.append(
AttentionWrapper(
channels=out_channels,
shape=out_shape,
grid=grid_out,
heads=heads,
pre_norm=True,
attention_drop_rate=att_drop_rate,
drop_path=drop_path_rates[i],
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
bias=bias
)
)
self.fwd.append(
MixFFN(
out_shape,
inout_channels=out_channels,
hidden_channels=mlp_hidden_channels,
mlp_bias=True,
grid=grid_out,
kernel_shape=kernel_shape,
basis_type=basis_type,
conv_bias=False,
activation=activation,
use_mlp=False,
drop_path=drop_path_rates[i],
)
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fwd(x)
# apply norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
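# block structure (summary of the constructor above): OverlapPatchMerging (downsampling),
# followed by nrep repetitions of [AttentionWrapper -> MixFFN], then a final LayerNorm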
class Upsampling(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
hidden_channels,
mlp_bias=True,
grid_in="equiangular",
grid_out="equiangular",
kernel_shape=(3, 3),
basis_type="morlet",
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
upsampling_method="conv"
):
super().__init__()
if use_mlp:
self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
if upsampling_method == "conv":
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = DiscreteContinuousConvTransposeS2(
out_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=conv_bias,
theta_cutoff=theta_cutoff,
)
elif upsampling_method == "bilinear":
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
else:
raise ValueError(f"Unknown upsampling method {upsampling_method}")
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample(self.mlp(x))
return x
class SphericalSegformer(nn.Module):
"""
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape of the input grid (nlat, nlon), by default (128, 256)
kernel_shape : tuple, optional
Kernel shape of the DISCO convolutions, by default (3, 3)
scale_factor : int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as heads
heads : List[int], optional
Number of heads for each block in the network, has to be the same length as embed_dims
depths : List[int], optional
Number of repetitions of attention blocks and FFN mixers per block. Has to be the same length as embed_dims and heads
activation_function : str, optional
Activation function to use, by default "gelu"
embedder_kernel_shape : int, optional
Size of the encoder kernel
filter_basis_type : str, optional
Filter basis type, by default "morlet"
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
upsampling_method : str, optional
Upsampling method to use ("conv" or "bilinear"), by default "bilinear"
Example
-----------
>>> model = SphericalSegformer(
... img_size=(128, 256),
... scale_factor=2,
... in_chans=2,
... out_chans=2,
... embed_dims=[16, 32, 64, 128],
... heads=[1, 2, 4, 8],
... depths=[1, 1, 1, 1])
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(3, 3),
filter_basis_type="morlet",
mlp_ratio=2.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
theta_cutoff=None,
upsampling_method="bilinear",
bias=True
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.heads = heads
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.heads) == self.num_blocks
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
self.blocks = nn.ModuleList([])
out_shape = img_size
grid_in = grid
grid_out = grid_internal
in_channels = in_chans
cur = 0
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.blocks.append(
TransformerBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
mlp_hidden_channels=int(mlp_ratio * out_channels),
grid_in=grid_in,
grid_out=grid_out,
nrep=self.depths[i],
heads=self.heads[i],
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
att_drop_rate=att_drop_rate,
drop_path_rates=dpr[cur : cur + self.depths[i]],
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
bias=bias
)
)
cur += self.depths[i]
out_shape = out_shape_new
grid_in = grid_internal
in_channels = out_channels
self.upsamplers = nn.ModuleList([])
out_shape = img_size
grid_out = grid
for i in range(self.num_blocks):
in_shape = self.blocks[i].out_shape
self.upsamplers.append(
Upsampling(
in_shape=in_shape,
out_shape=out_shape,
in_channels=self.embed_dims[i],
out_channels=self.embed_dims[i],
hidden_channels=int(mlp_ratio * self.embed_dims[i]),
mlp_bias=True,
grid_in=grid_internal,
grid_out=grid,
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
conv_bias=False,
activation=nn.GELU,
upsampling_method=upsampling_method
)
)
segmentation_head_dim = sum(self.embed_dims)
self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for block in self.blocks:
feat = block(feat)
features.append(feat)
# perform upsample
upfeats = []
for feat, upsampler in zip(features, self.upsamplers):
upfeats.append(upsampler(feat))
# perform concatenation
upfeats = torch.cat(upfeats, dim=1)
# final upsampling and prediction
out = self.segmentation_head(upfeats)
return out
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import NeighborhoodAttentionS2, AttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, DropPath, LayerNorm, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
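# worked example: for nlat = 128, kernel_shape = (3, 3) and the "morlet" basis this evaluates to
# (3 + 1) * 0.5 * pi / 127 = 2 * pi / 127, roughly 0.0495 rad or about 2.8 degrees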
class DiscreteContinuousEncoder(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
):
super().__init__()
# set up local convolution
self.conv = DiscreteContinuousConvS2(
in_chans,
out_chans,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.conv(x)
x = x.to(dtype=dtype)
return x
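# A minimal usage sketch (illustrative only; shapes follow the constructor defaults above):
# enc = DiscreteContinuousEncoder(in_shape=(721, 1440), out_shape=(480, 960), in_chans=2, out_chans=2)
# y = enc(torch.randn(1, 2, 721, 1440))  # -> shape (1, 2, 480, 960)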
class DiscreteContinuousDecoder(nn.Module):
def __init__(
self,
in_shape=(480, 960),
out_shape=(721, 1440),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
basis_type="morlet",
groups=1,
bias=False,
upsample_sht=False,
):
super().__init__()
# set up upsampling
if upsample_sht:
self.sht = RealSHT(*in_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
self.upsample = nn.Sequential(self.sht, self.isht)
else:
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution
self.conv = DiscreteContinuousConvS2(
in_chans,
out_chans,
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out,
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.upsample(x)
x = self.conv(x)
x = x.to(dtype=dtype)
return x
class SphericalAttentionBlock(nn.Module):
"""
Helper module for a single spherical attention block: (neighborhood or global) self-attention on the sphere,
followed by an optional MLP, with skip connections and stochastic depth (drop path).
"""
def __init__(
self,
in_shape=(480, 960),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
in_chans=2,
out_chans=2,
num_heads=1,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="none",
use_mlp=True,
bias=False,
attention_mode="neighborhood",
theta_cutoff=None,
):
super().__init__()
# normalisation layer
if norm_layer == "layer_norm":
self.norm0 = LayerNorm(in_channels=in_chans, eps=1e-6)
self.norm1 = LayerNorm(in_channels=out_chans, eps=1e-6)
elif norm_layer == "instance_norm":
self.norm0 = nn.InstanceNorm2d(num_features=in_chans, eps=1e-6, affine=True, track_running_stats=False)
self.norm1 = nn.InstanceNorm2d(num_features=out_chans, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm0 = nn.Identity()
self.norm1 = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# determine radius for neighborhood attention
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
if theta_cutoff is None:
theta_cutoff = (7.0 / math.sqrt(math.pi)) * math.pi / (in_shape[0] - 1)
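# note: (7.0 / sqrt(pi)) * pi equals 7 * sqrt(pi) ~ 12.4, so the default cutoff is roughly 12.4 / (nlat - 1)
# radians, e.g. about 0.026 rad (1.5 degrees) for nlat = 480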
self.self_attn = NeighborhoodAttentionS2(
in_channels=in_chans,
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
num_heads=num_heads,
theta_cutoff=theta_cutoff,
k_channels=None,
out_channels=out_chans,
bias=bias,
)
else:
self.self_attn = AttentionS2(
in_channels=in_chans,
num_heads=num_heads,
in_shape=in_shape,
out_shape=out_shape,
grid_in=grid_in,
grid_out=grid_out,
out_channels=out_chans,
drop_rate=drop_rate,
bias=bias,
)
self.skip0 = nn.Identity()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
if use_mlp:
mlp_hidden_dim = int(out_chans * mlp_ratio)
self.mlp = MLP(
in_features=out_chans,
out_features=out_chans,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop_rate=drop_rate,
checkpointing=False,
gain=0.5,
)
self.skip1 = nn.Identity()
def forward(self, x):
residual = x
x = self.norm0(x)
if self.attention_mode == "neighborhood":
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.self_attn(x).to(dtype=dtype)
else:
x = self.self_attn(x)
if hasattr(self, "skip0"):
x = x + self.skip0(residual)
residual = x
x = self.norm1(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, "skip1"):
x = x + self.skip1(residual)
return x
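# A minimal usage sketch (illustrative only, assuming the default neighborhood attention mode):
# blk = SphericalAttentionBlock(in_shape=(60, 120), out_shape=(60, 120), in_chans=8, out_chans=8, num_heads=2)
# y = blk(torch.randn(1, 8, 60, 120))  # -> shape (1, 8, 60, 120)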
class SphericalTransformer(nn.Module):
"""
Spherical transformer model designed to approximate mappings from spherical signals to spherical signals
Parameters
-----------
img_size : tuple, optional
Shape (nlat, nlon) of the input grid, by default (128, 256)
grid : str, optional
Grid type of the input and output signals, by default "equiangular"
grid_internal : str, optional
Grid type used internally, by default "legendre-gauss"
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Dimension of the embeddings, by default 256
num_layers : int, optional
Number of layers in the network, by default 4
activation_function : str, optional
Activation function to use, by default "gelu"
encoder_kernel_shape : tuple, optional
Kernel shape of the encoder/decoder convolutions, by default (3, 3)
filter_basis_type : str, optional
Filter basis type, by default "morlet"
num_heads : int, optional
Number of attention heads, by default 1
use_mlp : bool, optional
Whether to use MLPs in the attention blocks, by default True
mlp_ratio : float, optional
Ratio of MLP hidden dimension to embedding dimension, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional
Whether to add a single large skip connection, by default False
pos_embed : str, optional
Type of positional embedding ("sequence", "spectral", "learnable lat", "learnable latlon", "none"), by default "spectral"
upsample_sht : bool, optional
Use SHT upsampling in the decoder if True, else interpolation, by default False
attention_mode : str, optional
Attention variant; "neighborhood" uses local attention, otherwise global attention is used, by default "neighborhood"
bias : bool, optional
Whether to use a bias, by default False
theta_cutoff : float, optional
Cutoff radius for neighborhood attention; if None, a resolution-dependent heuristic is used
Example
-----------
>>> model = SphericalTransformer(
... img_size=(128, 256),
... scale_factor=4,
... in_chans=2,
... out_chans=2,
... embed_dim=16,
... num_layers=4,
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="gelu",
encoder_kernel_shape=(3, 3),
filter_basis_type="morlet",
num_heads=1,
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
residual_prediction=False,
pos_embed="spectral",
upsample_sht=False,
attention_mode="neighborhood",
bias=False,
theta_cutoff=None,
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = embed_dim
self.num_layers = num_layers
self.encoder_kernel_shape = encoder_kernel_shape
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = DiscreteContinuousEncoder(
in_shape=self.img_size,
out_shape=(self.h, self.w),
grid_in=grid,
grid_out=grid_internal,
in_chans=self.in_chans,
out_chans=self.embed_dim,
kernel_shape=self.encoder_kernel_shape,
basis_type=filter_basis_type,
groups=1,
bias=False,
)
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
block = SphericalAttentionBlock(
in_shape=(self.h, self.w),
out_shape=(self.h, self.w),
grid_in=grid_internal,
grid_out=grid_internal,
in_chans=self.embed_dim,
out_chans=self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
bias=bias,
attention_mode=attention_mode,
theta_cutoff=theta_cutoff,
)
self.blocks.append(block)
# decoder
self.decoder = DiscreteContinuousDecoder(
in_shape=(self.h, self.w),
out_shape=self.img_size,
grid_in=grid_internal,
grid_out=grid,
in_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
basis_type=filter_basis_type,
groups=1,
bias=False,
upsample_sht=upsample_sht,
)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.models._layers import MLP, DropPath
from functools import partial
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DownsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.0,
drop_path_rate=0.0,
drop_dense_rate=0.0,
downsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.grid_in = grid_in
self.grid_out = grid_out
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.fwd = []
for i in range(nrep):
# conv
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.fwd.append(
DiscreteContinuousConvS2(
in_channels=(in_channels if i == 0 else out_channels),
out_channels=out_channels,
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
)
if drop_conv_rate > 0.0:
self.fwd.append(nn.Dropout2d(p=drop_conv_rate))
# batchnorm
self.fwd.append(nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# activation
self.fwd.append(
activation(),
)
if downsampling_mode == "conv":
theta_cutoff = _compute_cutoff_radius(out_shape[0], kernel_shape, basis_type)
self.downsample = DiscreteContinuousConvS2(
out_channels,
out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
else:
self.downsample = ResampleS2(
nlat_in=in_shape[0],
nlon_in=in_shape[1],
nlat_out=out_shape[0],
nlon_out=out_shape[1],
grid_in=grid_in,
grid_out=grid_out,
mode=downsampling_mode,
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# 1x1 convolution to transform the skip connection when requested or when the channel count changes
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
if drop_dense_rate > 0.0:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = self.fwd(x)
# add residual connection
x = residual + self.drop_path(x)
# downsample
x = self.downsample(x)
return x
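# Shape sketch (descriptive comment): with in_shape=(128, 256) and out_shape=(64, 128), a (B, in_channels, 128, 256)
# input is refined at full resolution by the conv stack (plus residual) and then resampled to (B, out_channels, 64, 128).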
class UpsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
grid_in="equiangular",
grid_out="equiangular",
nrep=1,
kernel_shape=(3, 3),
basis_type="morlet",
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.0,
drop_path_rate=0.0,
drop_dense_rate=0.0,
upsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
if in_shape != out_shape:
if upsampling_mode == "conv":
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = nn.Sequential(
DiscreteContinuousConvTransposeS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
),
nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
activation(),
DiscreteContinuousConvS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
),
)
else:
self.upsample = ResampleS2(
nlat_in=in_shape[0],
nlon_in=in_shape[1],
nlat_out=out_shape[0],
nlon_out=out_shape[1],
grid_in=grid_in,
grid_out=grid_out,
mode=upsampling_mode,
)
else:
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.upsample = DiscreteContinuousConvS2(
in_channels=out_channels,
out_channels=out_channels,
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff,
)
self.fwd = []
for i in range(nrep):
# conv
theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
self.fwd.append(
DiscreteContinuousConvS2(
in_channels=in_channels,
out_channels=(out_channels if i == nrep - 1 else in_channels),
in_shape=in_shape,
out_shape=in_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_in,
bias=False,
theta_cutoff=theta_cutoff,
)
)
if drop_conv_rate > 0.0:
self.fwd.append(nn.Dropout2d(p=drop_conv_rate))
# batchnorm
self.fwd.append(nn.BatchNorm2d((out_channels if i == nrep - 1 else in_channels), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
# activation
self.fwd.append(
activation(),
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# 1x1 convolution to transform the skip connection when requested or when the channel count changes
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
if drop_dense_rate > 0.0:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = residual + self.drop_path(self.fwd(x))
# upsampling
x = self.upsample(x)
return x
class SphericalUNet(nn.Module):
"""
Spherical U-Net model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape (nlat, nlon) of the input grid, by default (128, 256)
grid : str, optional
Grid type of the input and output signals, by default "equiangular"
grid_internal : str, optional
Grid type used internally, by default "legendre-gauss"
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as depths
depths : List[int], optional
Number of repetitions of conv blocks per level, has to be the same length as embed_dims
scale_factor : int, optional
Scale factor applied per downsampling block, by default 2
activation_function : str, optional
Activation function to use, by default "relu"
kernel_shape : tuple, optional
Kernel shape of the DISCO convolutions, by default (3, 3)
filter_basis_type : str, optional
Filter basis type, by default "morlet"
transform_skip : bool, optional
Whether to always transform the skip connections with a 1x1 convolution, by default False
drop_conv_rate : float, optional
Dropout rate applied after the convolutions, by default 0.1
drop_path_rate : float, optional
Dropout path rate, by default 0.1
drop_dense_rate : float, optional
Dropout rate applied to the transformed skip connections, by default 0.5
downsampling_mode : str, optional
Downsampling method; "conv" uses a DISCO convolution, any other value is passed as resampling mode to ResampleS2, by default "bilinear"
upsampling_mode : str, optional
Upsampling method; "conv" uses DISCO (transpose) convolutions, any other value is passed as resampling mode to ResampleS2, by default "bilinear"
Example
-----------
>>> model = SphericalUNet(
... img_size=(128, 256),
... in_chans=2,
... out_chans=2,
... embed_dims=[16, 32],
... depths=[1, 1],
... scale_factor=2,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="relu",
kernel_shape=(3, 3),
filter_basis_type="morlet",
transform_skip=False,
drop_conv_rate=0.1,
drop_path_rate=0.1,
drop_dense_rate=0.5,
downsampling_mode="bilinear",
upsampling_mode="bilinear",
):
super().__init__()
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]
self.dblocks = nn.ModuleList([])
out_shape = img_size
grid_in = grid
grid_out = grid_internal
in_channels = in_chans
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.dblocks.append(
DownsamplingBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
grid_in=grid_in,
grid_out=grid_out,
nrep=self.depths[i],
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=dpr[i],
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
downsampling_mode=downsampling_mode,
)
)
out_shape = out_shape_new
grid_in = grid_internal
in_channels = out_channels
self.ublocks = nn.ModuleList([])
for i in range(self.num_blocks - 1, -1, -1):
in_shape = self.dblocks[i].out_shape
out_shape = self.dblocks[i].in_shape
in_channels = self.dblocks[i].out_channels
if i != self.num_blocks - 1:
in_channels = 2 * in_channels
out_channels = self.dblocks[i].in_channels
if i == 0:
out_channels = self.embed_dims[0]
grid_in = self.dblocks[i].grid_out
grid_out = self.dblocks[i].grid_in
self.ublocks.append(
UpsamplingBlock(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
grid_in=grid_in,
grid_out=grid_out,
kernel_shape=kernel_shape,
basis_type=filter_basis_type,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=0.0,
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
upsampling_mode=upsampling_mode,
)
)
self.head = nn.Conv2d(self.embed_dims[0], self.out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for dblock in self.dblocks:
feat = dblock(feat)
features.append(feat)
# reverse list
features = features[::-1]
# perform upsample
ufeat = self.ublocks[0](features[0])
for feat, ublock in zip(features[1:], self.ublocks[1:]):
ufeat = ublock(torch.cat([feat, ufeat], dim=1))
# last layer
out = self.head(ufeat)
return out
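# Data-flow sketch (descriptive comment) for two blocks: x -> d0 -> d1 on the encoder path; the first upsampling
# block maps d1 back up, its output is concatenated channel-wise with d0 and passed through the second upsampling
# block, and the 1x1 head maps the embed_dims[0] channels to out_chans class logits.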
@@ -28,12 +28,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
+ import math
import torch
import torch.nn as nn
from torch_harmonics import RealSHT, InverseRealSHT
- from ._layers import *
+ from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from functools import partial
@@ -53,10 +55,11 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
- norm_layer=nn.Identity,
+ norm_layer="none",
inner_skip="none",
outer_skip="identity",
use_mlp=True,
+ bias=False,
):
super().__init__()
@@ -68,7 +71,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
- self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)
+ self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
@@ -81,8 +84,15 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
- # first normalisation layer
- self.norm0 = norm_layer()
+ # normalisation layer
+ if norm_layer == "layer_norm":
+ self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
+ elif norm_layer == "instance_norm":
+ self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
+ elif norm_layer == "none":
+ self.norm = nn.Identity()
+ else:
+ raise NotImplementedError(f"Error, normalization {self.norm_layer} not implemented.")
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
@@ -108,14 +118,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
- # second normalisation layer
- self.norm1 = norm_layer()
def forward(self, x):
x, residual = self.global_conv(x)
- x = self.norm0(x)
+ x = self.norm(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
@@ -123,8 +131,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if hasattr(self, "mlp"):
x = self.mlp(x)
- x = self.norm1(x)
x = self.drop_path(x)
if hasattr(self, "outer_skip"):
@@ -133,7 +139,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
return x
- class SphericalFourierNeuralOperatorNet(nn.Module):
+ class SphericalFourierNeuralOperator(nn.Module):
"""
SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
@@ -169,14 +175,16 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
- big_skip : bool, optional
+ residual_prediction : bool, optional
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
+ bias : bool, optional
+ Whether to use a bias, by default False
Example:
--------
- >>> model = SphericalFourierNeuralOperatorNet(
+ >>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
... in_chans=2,
@@ -212,9 +220,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
- use_complex_kernels=True,
- big_skip=False,
- pos_embed=False,
+ residual_prediction=False,
+ pos_embed="none",
+ bias=False,
):
super().__init__()
@@ -231,7 +239,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
- self.big_skip = big_skip
+ self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
@@ -252,30 +260,18 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
- # pick norm layer
- if self.normalization_layer == "layer_norm":
- norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
- norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
- elif self.normalization_layer == "instance_norm":
- norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
- norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
- elif self.normalization_layer == "none":
- norm_layer0 = nn.Identity
- norm_layer1 = norm_layer0
- else:
- raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
- if pos_embed == "latlon" or pos_embed == True:
- self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
- nn.init.constant_(self.pos_embed, 0.0)
- elif pos_embed == "lat":
- self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1))
- nn.init.constant_(self.pos_embed, 0.0)
- elif pos_embed == "const":
- self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
- nn.init.constant_(self.pos_embed, 0.0)
+ if pos_embed == "sequence":
+ self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
+ elif pos_embed == "spectral":
+ self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
+ elif pos_embed == "learnable lat":
+ self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
+ elif pos_embed == "learnable latlon":
+ self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
+ elif pos_embed == "none":
+ self.pos_embed = nn.Identity()
else:
- self.pos_embed = None
+ raise ValueError(f"Unknown position embedding type {pos_embed}")
# construct an encoder with num_encoder_layers
num_encoder_layers = 1
@@ -292,7 +288,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers.append(fc)
encoder_layers.append(self.activation_function())
current_dim = encoder_hidden_dim
- fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
+ fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=bias)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
@@ -318,13 +314,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
first_layer = i == 0
last_layer = i == self.num_layers - 1
- if first_layer:
- norm_layer = norm_layer1
- elif last_layer:
- norm_layer = norm_layer0
- else:
- norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(
self.trans_down if first_layer else self.trans,
self.itrans_up if last_layer else self.itrans,
@@ -334,8 +323,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
- norm_layer=norm_layer,
+ norm_layer=self.normalization_layer,
use_mlp=use_mlp,
+ bias=bias,
)
self.blocks.append(block)
@@ -343,7 +333,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
# construct an decoder with num_decoder_layers
num_decoder_layers = 1
decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
- current_dim = self.embed_dim + self.big_skip * self.in_chans
+ current_dim = self.embed_dim
decoder_layers = []
for l in range(num_decoder_layers - 1):
fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
@@ -355,7 +345,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
decoder_layers.append(fc)
decoder_layers.append(self.activation_function())
current_dim = decoder_hidden_dim
- fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
+ fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=bias)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
@@ -378,19 +368,19 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def forward(self, x):
- if self.big_skip:
+ if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
- x = x + self.pos_embed
+ x = self.pos_embed(x)
x = self.forward_features(x)
- if self.big_skip:
- x = torch.cat((x, residual), dim=1)
x = self.decoder(x)
+ if self.residual_prediction:
+ x = x + residual
return x
@@ -32,7 +32,7 @@
import torch
import torch.nn as nn
- import torch_harmonics as harmonics
+ import torch_harmonics as th
from torch_harmonics.quadrature import _precompute_longitudes
import math
@@ -61,19 +61,19 @@ class SphereSolver(nn.Module):
self.register_buffer('coeff', torch.as_tensor(coeff, dtype=torch.float64))
# SHT
- self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
- self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.lmax = lmax or self.sht.lmax
self.mmax = lmax or self.sht.mmax
# compute gridpoints
if self.grid == "legendre-gauss":
- cost, _ = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
+ cost, _ = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
elif self.grid == "lobatto":
- cost, _ = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
+ cost, _ = th.quadrature.lobatto_weights(self.nlat, -1, 1)
elif self.grid == "equiangular":
- cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
+ cost, _ = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
# apply cosine transform and flip them
lats = -torch.arcsin(cost)
...
@@ -32,7 +32,7 @@
import torch
import torch.nn as nn
- import torch_harmonics as harmonics
+ import torch_harmonics as th
from torch_harmonics.quadrature import _precompute_longitudes
import math
@@ -64,21 +64,21 @@ class ShallowWaterSolver(nn.Module):
self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64))
# SHT
- self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
- self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
- self.vsht = harmonics.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.vsht = th.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
- self.ivsht = harmonics.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
+ self.ivsht = th.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
self.lmax = lmax or self.sht.lmax
self.mmax = lmax or self.sht.mmax
# compute gridpoints
if self.grid == "legendre-gauss":
- cost, quad_weights = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
+ cost, quad_weights = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
elif self.grid == "lobatto":
- cost, quad_weights = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
+ cost, quad_weights = th.quadrature.lobatto_weights(self.nlat, -1, 1)
elif self.grid == "equiangular":
- cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
+ cost, quad_weights = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
quad_weights = quad_weights.reshape(-1, 1)
...
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import math
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.losses import get_quadrature_weights
# some specifiers where to find the dataset
DEFAULT_BASE_URL = "https://cvg-data.inf.ethz.ch/2d3ds/no_xyz/"
DEFAULT_TAR_FILE_PAIRS = [
("area_1_no_xyz.tar", "area_1"),
("area_2_no_xyz.tar", "area_2"),
("area_3_no_xyz.tar", "area_3"),
("area_4_no_xyz.tar", "area_4"),
("area_5a_no_xyz.tar", "area_5a"),
("area_5b_no_xyz.tar", "area_5b"),
("area_6_no_xyz.tar", "area_6"),
]
DEFAULT_LABELS_URL = "https://raw.githubusercontent.com/alexsax/2D-3D-Semantics/refs/heads/master/assets/semantic_labels.json"
class Stanford2D3DSDownloader:
"""
Convenience class for downloading the 2d3ds dataset [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
self.base_url = base_url
self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename):
import requests
from tqdm import tqdm
url = f"{self.base_url}/{filename}"
local_path = os.path.join(self.local_dir, filename)
if os.path.exists(local_path):
print(f"Note: Skipping download for {filename}, because it already exists")
return local_path
print(f"Downloading {filename}...")
temp_path = local_path.split(".")[0] + ".part"
# Resume logic
headers = {}
if os.path.exists(temp_path):
headers = {"Range": f"bytes={os.stat(temp_path).st_size}-"}
response = requests.get(url, headers=headers, stream=True, timeout=30)
if os.path.exists(temp_path):
total_size = int(response.headers.get("content-length", 0)) + os.stat(temp_path).st_size
else:
total_size = int(response.headers.get("content-length", 0))
with open(temp_path, "ab") as f, tqdm(desc=filename, total=total_size, unit="B", unit_scale=True, unit_divisor=1024, initial=os.stat(temp_path).st_size) as pbar:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
os.rename(temp_path, local_path)
return local_path
def _extract_tar(self, tar_path):
import tarfile
with tarfile.open(tar_path) as tar:
tar.extractall(path=self.local_dir)
tar_filenames = tar.getnames()
extracted_dir = tar_filenames[0]
os.remove(tar_path)
return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
import requests
data_folders = []
for file, extracted_folder_name in file_extracted_directory_pairs:
if not os.path.exists(os.path.join(self.local_dir, extracted_folder_name)):
downloaded_file = self._download_file(file)
data_folders.append(self._extract_tar(downloaded_file))
else:
print(f"Warning: Skipping D/L for '{file}' because folder '{extracted_folder_name}' already exists")
data_folders.append(extracted_folder_name)
labels_json_url = DEFAULT_LABELS_URL
class_labels = requests.get(labels_json_url).json()
return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
# Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32)
b = img[..., 2].astype(np.int32)
lookup_indices = r * 256 * 256 + g * 256 + b
def _convert(lookup: int) -> int:
# the dataset has a bad label for the clutter class: decoding its RGB color yields 855309,
# but the labels file lists it at index 3341
# the original conversion used uint8 arithmetic, which wrapped the clutter value around to 3341,
# so this special case keeps the mapping consistent with that accidental overflow
if lookup == 855309:
label = class_labels_map[3341] # clutter
else:
label = class_labels_map[lookup]
class_index = class_labels_indices.index(label)
return class_index
lookup_fn = np.vectorize(_convert)
return lookup_fn(lookup_indices)
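# lookup sketch: a semantic pixel with RGB (1, 2, 3) decodes to 1 * 65536 + 2 * 256 + 3 = 66051, which indexes
# class_labels_map; the returned id is the position of that label in the sorted list of unique class names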
def convert_dataset(
self,
data_folders,
class_labels,
rgb_path: str = "pano/rgb",
semantic_path: str = "pano/semantic",
depth_path: str = "pano/depth",
output_filename="semantic",
dataset_file: str = "stanford_2d3ds_dataset.h5",
downsampling_factor: int = 16,
remove_alpha_channel: bool = True,
):
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image
from tqdm import tqdm
import h5py as h5
file_paths = []
min_vals = None
max_vals = None
# condition class labels first:
class_labels_map = [label.split("_")[0] for label in class_labels]
class_labels_indices = sorted(list(set(class_labels_map)))
# get all the file path input, output pairs
for base_path in data_folders:
rgb_dir = os.path.join(self.local_dir, base_path, rgb_path)
semantic_dir = os.path.join(self.local_dir, base_path, semantic_path)
depth_dir = os.path.join(self.local_dir, base_path, depth_path)
if os.path.exists(rgb_dir) and os.path.exists(semantic_dir) and os.path.exists(depth_dir):
for file_input in os.listdir(rgb_dir):
if not file_input.endswith(".png"):
continue
rgb_filepath = os.path.join(rgb_dir, file_input)
semantic_filepath = "_".join(os.path.splitext(os.path.basename(rgb_filepath))[0].split("_")[:-1]) + f"_{output_filename}.png"
semantic_filepath = os.path.join(semantic_dir, semantic_filepath)
depth_filepath = "_".join(os.path.splitext(os.path.basename(rgb_filepath))[0].split("_")[:-1]) + f"_depth.png"
depth_filepath = os.path.join(depth_dir, depth_filepath)
if not os.path.exists(semantic_filepath):
print(f"Warning: Couldn't find output file in pair: ({rgb_filepath},{semantic_filepath})")
continue
if not os.path.exists(depth_filepath):
print(f"Warning: Couldn't find depth file in pair: ({rgb_filepath},{depth_filepath})")
continue
file_paths.append((rgb_filepath, semantic_filepath, depth_filepath))
elif not os.path.exists(rgb_dir):
print("Warning: RGB dir doesn't exist: ", rgb_dir)
continue
elif not os.path.exists(semantic_dir):
print("Warning: Semantic dir doesn't exist: ", semantic_dir)
continue
elif not os.path.exists(depth_dir):
print("Warning: Depth dir doesn't exist: ", depth_dir)
continue
num_samples = len(file_paths)
if num_samples > 0:
first_rgb, first_semantic, first_depth = file_paths[0]
first_rgb = np.array(Image.open(first_rgb))
# first_semantic = np.array(Image.open(first_semantic))
# first_depth = np.array(Image.open(first_depth))
rgb_shape = first_rgb.shape
img_shape = (rgb_shape[0] // downsampling_factor, rgb_shape[1] // downsampling_factor)
rgb_channels = rgb_shape[2]
if remove_alpha_channel:
rgb_channels = 3
else:
raise ValueError(f"No samples found")
# create the dataset file
with h5.File(converted_dataset_path, "w") as h5file:
rgb_data = h5file.create_dataset("rgb", (num_samples, rgb_channels, *img_shape), "f4")
semantic_data = h5file.create_dataset("semantic", (num_samples, *img_shape), "i8")
depth_data = h5file.create_dataset("depth", (num_samples, *img_shape), "f4")
classes = h5file.create_dataset("class_labels", data=class_labels_indices)
num_classes = len(set(class_labels_indices))
data_source_path = h5file.create_dataset("data_source_path", (num_samples,), dtype=h5.string_dtype(encoding="utf-8"))
data_target_path = h5file.create_dataset("data_target_path", (num_samples,), dtype=h5.string_dtype(encoding="utf-8"))
# prepare computation of the class histogram
class_histogram = np.zeros(num_classes)
_, quad_weights = _precompute_latitudes(nlat=img_shape[0], grid="equiangular")
quad_weights = quad_weights.reshape(-1, 1) * 2 * torch.pi / float(img_shape[1])
quad_weights = quad_weights.tile(1, img_shape[1])
quad_weights /= torch.sum(quad_weights)
quad_weights = quad_weights.numpy()
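# after normalization the weights sum to one and approximate the relative area of each pixel on the sphere,
# so the weighted sums below yield area-weighted (spherical) statistics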
for count in tqdm(range(num_samples), desc="preparing dataset"):
# open image
img = Image.open(file_paths[count][0])
# downsampling
if downsampling_factor != 1:
# note: PIL.Image.resize expects the size as (width, height)
img = img.resize(size=(img_shape[1], img_shape[0]), resample=Image.BILINEAR)
# remove alpha channel if requested
if remove_alpha_channel:
img = img.convert("RGBA")
background = Image.new("RGBA", img.size, (255, 255, 255))
# composite foreground and background and remove the alpha channel
img = np.array(Image.alpha_composite(background, img))
r_data = img[:, :, :3]
else:
r_data = np.array(img)
# transpose to channels first
r_data = np.transpose(r_data / 255.0, axes=(2, 0, 1))
# write to disk
rgb_data[count, ...] = r_data[...]
data_source_path[count] = file_paths[count][0]
# compute stats -> segmentation
# min/max
tmp_min = np.min(r_data, axis=(1, 2))
tmp_max = np.max(r_data, axis=(1, 2))
# mean/var
tmp_mean = np.sum(r_data * quad_weights[np.newaxis, :, :], axis=(1, 2))
tmp_m2 = np.sum(np.square(r_data - tmp_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(1, 2))
if count == 0:
# min/max
min_vals = tmp_min
max_vals = tmp_max
# mean/var
mean_vals = tmp_mean
m2_vals = tmp_m2
else:
# min/max
min_vals = np.minimum(min_vals, tmp_min)
max_vals = np.maximum(max_vals, tmp_max)
# mean/var
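# streaming combination of per-image statistics: tmp_m2 carries the within-image squared deviations and the
# delta term adds the between-image contribution (Chan et al. style parallel mean/variance update)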
delta = tmp_mean - mean_vals
mean_vals += delta / float(count + 1)
m2_vals += tmp_m2 + delta * delta * float(count / (count + 1))
# get the target
sem = Image.open(file_paths[count][1])
# downsampling
if downsampling_factor != 1:
sem = sem.resize(size=(img_shape[1], img_shape[0]), resample=Image.NEAREST)
sem_data = np.array(sem, dtype=np.uint32)
# map to classes
sem_data = self._rgb_to_id(sem_data, class_labels_map, class_labels_indices)
# write to file
semantic_data[count, ...] = sem_data[...]
data_target_path[count] = file_paths[count][1]
# Here we want depth
dep = Image.open(file_paths[count][2])
if downsampling_factor != 1:
dep = dep.resize(size=(img_shape[1], img_shape[0]), resample=Image.NEAREST)
dep_data = np.array(dep)
depth_data[count, ...] = dep_data[...] / 65536.0
# compute stats -> depth
# min/max
tmp_min_depth = np.min(dep_data, axis=(0, 1))
tmp_max_depth = np.max(dep_data, axis=(0, 1))
# mean/var
tmp_mean_depth = np.sum(dep_data * quad_weights[:, :])
tmp_m2_depth = np.sum(np.square(dep_data - tmp_mean_depth) * quad_weights[:, :])
if count == 0:
min_vals_depth = tmp_min_depth
max_vals_depth = tmp_max_depth
mean_vals_depth = tmp_mean_depth
m2_vals_depth = tmp_m2_depth
else:
min_vals_depth = np.minimum(min_vals_depth, tmp_min_depth)
max_vals_depth = np.maximum(max_vals_depth, tmp_max_depth)
delta = tmp_mean_depth - mean_vals_depth
mean_vals_depth += delta / float(count + 1)
m2_vals_depth += tmp_m2_depth + delta * delta * float(count / (count + 1))
# update the class histogram
for c in range(num_classes):
class_histogram[c] += quad_weights[sem_data == c].sum()
# record min/max
h5file.create_dataset("min_rgb", data=min_vals.astype(np.float32))
h5file.create_dataset("max_rgb", data=max_vals.astype(np.float32))
h5file.create_dataset("mean_rgb", data=mean_vals.astype(np.float32))
std_vals = np.sqrt(m2_vals / float(num_samples - 1))
h5file.create_dataset("std_rgb", data=std_vals.astype(np.float32))
# record min/max
h5file.create_dataset("min_depth", data=min_vals_depth.astype(np.float32))
h5file.create_dataset("max_depth", data=max_vals_depth.astype(np.float32))
h5file.create_dataset("mean_depth", data=mean_vals_depth.astype(np.float32))
std_vals_depth = np.sqrt(m2_vals_depth / float(num_samples - 1))
h5file.create_dataset("std_depth", data=std_vals_depth.astype(np.float32))
# record class histogram
class_histogram = class_histogram / num_samples
h5file.create_dataset("class_histogram", data=class_histogram.astype(np.float32))
return converted_dataset_path
def prepare_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS, dataset_file: str = "stanford_2d3ds_dataset.h5", downsampling_factor: int = 16):
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
if os.path.exists(converted_dataset_path):
print(
f"Dataset file at {converted_dataset_path} already exists. Skipping download and conversion. If you want to create a new dataset file, delete or rename the existing file."
)
return converted_dataset_path
data_folders, class_labels = self.download_dataset(file_extracted_directory_pairs=file_extracted_directory_pairs)
converted_dataset_path = self.convert_dataset(data_folders=data_folders, class_labels=class_labels, dataset_file=dataset_file, downsampling_factor=downsampling_factor)
self.converted_dataset_path = converted_dataset_path
return self.converted_dataset_path
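For reference, the running mean/variance accumulation above is the pairwise (parallel) variance merge with one new aggregate per image: each image contributes a quadrature-weighted mean $\mu_B$ and second central moment $M_{2,B}$, which are folded into the running values $\mu$, $M_2$ over the $n$ images seen so far:

$$
\delta = \mu_B - \mu, \qquad \mu \leftarrow \mu + \frac{\delta}{n + 1}, \qquad M_2 \leftarrow M_2 + M_{2,B} + \delta^2 \, \frac{n}{n + 1}.
$$

The standard deviations written to the HDF5 file are then $\sqrt{M_2 / (N - 1)}$ over all $N$ samples.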
class StanfordSegmentationDataset(Dataset):
"""
Spherical segmentation dataset from [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(
self,
dataset_file,
ignore_alpha_channel=True,
exclude_polar_fraction=0,
):
import h5py as h5
self.dataset_file = dataset_file
self.exclude_polar_fraction = exclude_polar_fraction
with h5.File(self.dataset_file, "r") as h5file:
self.img_rgb = h5file["rgb"][0].shape
self.img_seg = h5file["semantic"][0].shape
self.num_samples = h5file["rgb"].shape[0]
self.num_classes = h5file["class_labels"].shape[0]
self.class_labels = [class_name.decode("utf-8") for class_name in h5file["class_labels"][...].tolist()]
self.class_histogram = np.array(h5file["class_histogram"][...])
self.class_histogram = self.class_histogram / self.class_histogram.sum()
self.mean = h5file["mean_rgb"][...]
self.std = h5file["std_rgb"][...]
self.min = h5file["min_rgb"][...]
self.max = h5file["max_rgb"][...]
self.img_filepath = h5file["data_source_path"][...]
self.tar_filepath = h5file["data_target_path"][...]
if ignore_alpha_channel:
self.img_rgb = (3, self.img_rgb[1], self.img_rgb[2])
# HDF5 file handle and datasets are opened lazily in _init_files
self.h5file = None
self.rgb = None
self.semantic = None
# return index set to false by default
# when true, the __getitem__ method will return the index of the input,target pair
self.return_index = False
@property
def target_shape(self):
return self.img_seg
@property
def input_shape(self):
return self.img_rgb
def set_return_index(self, return_index: bool):
self.return_index = return_index
def get_img_filepath(self, idx: int):
return self.img_filepath[idx]
def get_tar_filepath(self, idx: int):
return self.tar_filepath[idx]
def _id_to_class(self, class_id):
if class_id >= self.num_classes:
print("WARNING: ID >= number of classes!")
return None
return self.class_labels[class_id]
def _mask_invalid(self, tar):
return np.where(tar >= self.num_classes, -100, tar)
def __len__(self):
return self.num_samples
def _init_files(self):
import h5py as h5
self.h5file = h5.File(self.dataset_file, "r")
self.rgb = self.h5file["rgb"]
self.semantic = self.h5file["semantic"]
def reset(self):
self.rgb = None
self.semantic = None
if self.h5file is not None:
self.h5file.close()
del self.h5file
self.h5file = None
def __getitem__(self, idx, mask_invalid=True):
if self.h5file is None:
# init files
self._init_files()
rgb = self.rgb[idx, 0 : self.img_rgb[0], 0 : self.img_rgb[1], 0 : self.img_rgb[2]]
sem = self.semantic[idx, 0 : self.img_seg[0], 0 : self.img_seg[1]]
if mask_invalid:
sem = self._mask_invalid(sem)
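# optionally set a band of rows at each pole to the ignore index (-100)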
if self.exclude_polar_fraction > 0:
hcut = int(self.exclude_polar_fraction * sem.shape[0])
if hcut > 0:
sem[0:hcut, :] = -100
sem[-hcut:, :] = -100
return rgb, sem
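A minimal usage sketch (the dataset filename and loader settings are placeholders; use whatever prepare_dataset produced). Because the HDF5 handle is opened lazily in _init_files, the dataset can be handed directly to a multi-worker DataLoader:

from torch.utils.data import DataLoader

dataset = StanfordSegmentationDataset("stanford_2d3ds_dataset.h5", exclude_polar_fraction=0.05)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for rgb, sem in loader:
    # rgb has shape (batch, channels, nlat, nlon); sem has shape (batch, nlat, nlon) with -100 marking ignored pixels
    break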
class StanfordDatasetSubset(Subset):
def __init__(self, dataset, indices, return_index=False):
super().__init__(dataset, indices)
self.return_index = return_index
self.dataset = dataset
def set_return_index(self, value):
self.return_index = value
def __getitem__(self, index):
real_index = self.indices[index]
data = self.dataset[real_index]
if self.return_index:
return data[0], data[1], real_index
else:
# Otherwise, return only (data, target)
return data[0], data[1]
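A possible train/validation split built on top of this wrapper, reusing the dataset from the sketch above (split fraction and seed are arbitrary); return_index=True makes the subset also return the index into the full dataset:

import numpy as np

rng = np.random.default_rng(333)
indices = rng.permutation(len(dataset)).tolist()
n_train = int(0.8 * len(dataset))
train_set = StanfordDatasetSubset(dataset, indices[:n_train])
valid_set = StanfordDatasetSubset(dataset, indices[n_train:], return_index=True)
rgb, sem, idx = valid_set[0]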
class StanfordDepthDataset(Dataset):
"""
Spherical depth estimation dataset from [1].
References
-----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(self, dataset_file, ignore_alpha_channel=True, log_depth=False, exclude_polar_fraction=0.0):
import h5py as h5
self.dataset_file = dataset_file
self.log_depth = log_depth
self.exclude_polar_fraction = exclude_polar_fraction
with h5.File(self.dataset_file, "r") as h5file:
self.img_rgb = h5file["rgb"][0].shape
self.img_depth = h5file["depth"][0].shape
self.num_samples = h5file["rgb"].shape[0]
self.mean_in = h5file["mean_rgb"][...]
self.std_in = h5file["std_rgb"][...]
self.min_in = h5file["min_rgb"][...]
self.max_in = h5file["max_rgb"][...]
self.mean_out = h5file["mean_depth"][...]
self.std_out = h5file["std_depth"][...]
self.min_out = h5file["min_depth"][...]
self.max_out = h5file["max_depth"][...]
if ignore_alpha_channel:
self.img_rgb = (3, self.img_rgb[1], self.img_rgb[2])
# HDF5 file handle and datasets are opened lazily in _init_files
self.h5file = None
self.rgb = None
self.depth = None
@property
def target_shape(self):
return self.img_depth
@property
def input_shape(self):
return self.img_rgb
def __len__(self):
return self.num_samples
def _init_files(self):
import h5py as h5
self.h5file = h5.File(self.dataset_file, "r")
self.rgb = self.h5file["rgb"]
self.depth = self.h5file["depth"]
def reset(self):
self.rgb = None
self.depth = None
if self.h5file is not None:
self.h5file.close()
del self.h5file
self.h5file = None
def _mask_invalid(self, tar):
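# depth values equal to the per-sample maximum are treated as missing and zeroed out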
return tar * np.where(tar == tar.max(), 0, 1)
def __getitem__(self, idx, mask_invalid=True):
if self.h5file is None:
# init files
self._init_files()
rgb = self.rgb[idx, 0 : self.img_rgb[0], 0 : self.img_rgb[1], 0 : self.img_rgb[2]]
depth = self.depth[idx, 0 : self.img_depth[0], 0 : self.img_depth[1]]
if mask_invalid:
depth = self._mask_invalid(depth)
if self.exclude_polar_fraction > 0:
hcut = int(self.exclude_polar_fraction * depth.shape[0])
if hcut > 0:
depth[0:hcut, :] = 0
depth[-hcut:, :] = 0
if self.log_depth:
depth = np.log(1 + depth)
return rgb, depth
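When log_depth=True, the target returned above is log(1 + depth). A hedged sketch of mapping a prediction in that space back to the normalized depth (log_depth_pred is a placeholder for a model output, not part of this module):

import numpy as np

depth_pred = np.expm1(log_depth_pred)  # inverse of log(1 + depth); log_depth_pred is a hypothetical model output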
def compute_stats_s2(dataset: Dataset, normalize_target: bool = False):
"""
Compute statistics using a parallel Welford reduction and quadrature weights on the sphere. The parallel Welford reduction follows the parallel algorithm described in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
"""
nexamples = len(dataset)
count = 0
for isample in range(nexamples):
token = dataset[isample]
# inp has shape (channels, nlat, nlon); tar has shape (nlat, nlon)
inp, tar = token
nlat = tar.shape[-2]
nlon = tar.shape[-1]
# pre-compute quadrature weights
if isample == 0:
quad_weights = get_quadrature_weights(nlat=inp.shape[1], nlon=inp.shape[2], grid="equiangular", tile=True).numpy().astype(np.float64)
# this is a special case for the depth dataset
# TODO: maybe make this an argument
if normalize_target:
mask = np.where(tar == 0, 0, 1)
masked_area = np.sum(mask * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# get initial welford values
if isample == 0:
# input
inp_means = np.sum(inp * quad_weights[np.newaxis, :, :], axis=(-2, -1))
inp_m2s = np.sum(np.square(inp - inp_means[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# target
if normalize_target:
tar_means = np.sum(mask * tar * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
tar_m2s = np.sum(mask * np.square(tar - tar_means[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
# update count
count = 1
# do welford update
else:
# input
# get new mean and m2
inp_mean = np.sum(inp * quad_weights[np.newaxis, :, :], axis=(-2, -1))
inp_m2 = np.sum(np.square(inp - inp_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1))
# update welford values
inp_delta = inp_mean - inp_means
inp_m2s = inp_m2s + inp_m2 + inp_delta**2 * count / float(count + 1)
inp_means = inp_means + inp_delta / float(count + 1)
# target
if normalize_target:
# get new mean and m2
tar_mean = np.sum(mask * tar * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
tar_m2 = np.sum(mask * np.square(tar - tar_mean[:, np.newaxis, np.newaxis]) * quad_weights[np.newaxis, :, :], axis=(-2, -1)) / masked_area
# update welford values
tar_delta = tar_mean - tar_means
tar_m2s = tar_m2s + tar_m2 + tar_delta**2 * count / float(count + 1)
tar_means = tar_means + tar_delta / float(count + 1)
# update count
count += 1
# finalize
inp_stds = np.sqrt(inp_m2s / float(count))
result = (inp_means.astype(np.float32), inp_stds.astype(np.float32))
if normalize_target:
tar_stds = np.sqrt(tar_m2s / float(count))
result += (tar_means.astype(np.float32), tar_stds.astype(np.float32))
return result
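A minimal sketch of how these statistics might be used for input normalization. Here depth_dataset is assumed to be a StanfordDepthDataset instance; with normalize_target=True the function returns four arrays:

in_mean, in_std, out_mean, out_std = compute_stats_s2(depth_dataset, normalize_target=True)
rgb, depth = depth_dataset[0]
rgb_normalized = (rgb - in_mean[:, None, None]) / in_std[:, None, None]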
@@ -30,70 +30,158 @@
#

import numpy as np
import os

# guarded imports
try:
    import matplotlib.pyplot as plt
except ImportError as err:
    plt = None

try:
    import cartopy
    import cartopy.crs as ccrs
except ImportError as err:
    cartopy = None
    ccrs = None


def check_plotting_dependencies():
    if plt is None:
        raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
    if cartopy is None:
        raise ImportError("cartopy is required for map plotting. Install it with 'pip install cartopy'")


def get_projection(
    projection,
    central_latitude=0,
    central_longitude=0,
):
    if projection == "orthographic":
        proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
    elif projection == "robinson":
        proj = ccrs.Robinson(central_longitude=central_longitude)
    elif projection == "platecarree":
        proj = ccrs.PlateCarree(central_longitude=central_longitude)
    elif projection == "mollweide":
        proj = ccrs.Mollweide(central_longitude=central_longitude)
    else:
        raise ValueError(f"Unknown projection mode {projection}")

    return proj


def plot_sphere(
    data, fig=None, projection="robinson", cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs
):
    """
    Plots a function defined on the sphere using pcolormesh
    """

    # make sure cartopy exists
    check_plotting_dependencies()

    if fig is None:
        fig = plt.figure()

    nlat = data.shape[-2]
    nlon = data.shape[-1]
    if lon is None:
        lon = np.linspace(0, 2 * np.pi, nlon + 1)[:-1]
    if lat is None:
        lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
    Lon, Lat = np.meshgrid(lon, lat)

    # convert radians to degrees
    Lon = Lon * 180 / np.pi
    Lat = Lat * 180 / np.pi

    # get the projection
    proj = get_projection(projection, central_latitude=central_latitude, central_longitude=central_longitude)
    ax = fig.add_subplot(projection=proj)

    # contour data over the map
    im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)

    # add features if requested
    if coastlines:
        ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)

    # add colorbar if requested
    if colorbar:
        plt.colorbar(im)

    # add gridlines
    if gridlines:
        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1, color="gray", alpha=0.6, linestyle="--")

    # add title with smaller font
    plt.title(title, y=1.05, fontsize=8)

    return im


def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
    """
    Displays an image on the sphere
    """

    # make sure cartopy exists
    check_plotting_dependencies()

    if fig is None:
        fig = plt.figure()

    # get the projection. Longitude is shifted to match plot_sphere
    proj = get_projection(projection, central_latitude=central_latitude, central_longitude=central_longitude + 180)
    ax = fig.add_subplot(projection=proj)

    # draw the image over the map
    im = ax.imshow(data, transform=ccrs.PlateCarree(), **kwargs)

    # add title
    plt.title(title, y=1.05)

    return im


# def plot_data(data,
#               fig=None,
#               cmap="RdBu",
#               title=None,
#               colorbar=False,
#               coastlines=False,
#               central_longitude=0,
#               lon=None,
#               lat=None,
#               **kwargs):
#     if fig == None:
#         fig = plt.figure()
#     nlat = data.shape[-2]
#     nlon = data.shape[-1]
#     if lon is None:
#         lon = np.linspace(0, 2*np.pi, nlon+1)[:-1]
#     if lat is None:
#         lat = np.linspace(np.pi/2., -np.pi/2., nlat)
#     Lon, Lat = np.meshgrid(lon, lat)
#     proj = ccrs.Robinson(central_longitude=central_longitude)
#     # proj = ccrs.Mollweide(central_longitude=central_longitude)
#     ax = fig.add_subplot(projection=proj)
#     Lon = Lon*180/np.pi
#     Lat = Lat*180/np.pi
#     # contour data over the map.
#     im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
#     if coastlines:
#         ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
#     if colorbar:
#         plt.colorbar(im)
#     plt.title(title, y=1.05)
#     return im
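A hedged usage sketch of the plotting helpers above; the import path is an assumption about where this module lives in the package, so adjust it as needed:

import numpy as np
from torch_harmonics.plotting import plot_sphere  # assumed import path

nlat, nlon = 91, 180
data = np.random.randn(nlat, nlon)  # random field on an equiangular (nlat, nlon) grid
im = plot_sphere(data, projection="mollweide", cmap="RdBu", colorbar=True, title="random field")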
@@ -121,14 +121,15 @@ class ResampleS2(nn.Module):

        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _upscale_longitudes(self, x: torch.Tensor):
        # do the interpolation in precision of x
        lwgt = self.lon_weights.to(x.dtype)
        if self.mode == "bilinear":
            x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)
        else:
            omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
            somega = torch.sin(omega)
            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
            end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
            x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]

        return x

@@ -142,14 +143,15 @@ class ResampleS2(nn.Module):

        return x

    def _upscale_latitudes(self, x: torch.Tensor):
        # do the interpolation in precision of x
        lwgt = self.lat_weights.to(x.dtype)
        if self.mode == "bilinear":
            x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)
        else:
            omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
            somega = torch.sin(omega)
            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
            end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
            x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]

        return x