Commit 6a845fd3 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev

adding spherical attention

parent b3816ebc
*.DS_Store
__pycache__
*.so
checkpoints
\ No newline at end of file
......@@ -2,6 +2,9 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Max Rietmann - NVIDIA Corporation
Andrea Paris - NVIDIA Corporation
Alberto Carpentieri - NVIDIA Corporation
Mauro Bisson - NVIDIA Corporation
Massimiliano Fatica - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
......
......@@ -2,6 +2,19 @@
## Versioning
### v0.8.0
* Adding spherical attention and spherical neighborhood attention
* Custom CUDA kernels for spherical neighborhood attention
* New datasets for segmentation and depth estimation on the sphere based on the 2D3DS dataset
* Added new spherical architectures and corresponding baselines
* S2 Transformer
* S2 Segformer
* S2 U-Net
* Reworked spherical examples for better reproducibility
* Added spherical loss functions to examples
* Added plotting module
### v0.7.6
* Adding cache for precomputed tensors such as weight tensors for DISCO and SHT
......
......@@ -30,15 +30,14 @@
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:24.08-py3
COPY . /workspace/torch_harmonics
FROM nvcr.io/nvidia/pytorch:24.12-py3
# we need this for tests
RUN pip install parameterized
# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
COPY . /workspace/torch_harmonics
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -29,68 +27,34 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.nn as nn
# complex activation functions
class ComplexCardioid(nn.Module):
"""
Complex Cardioid activation function
"""
def __init__(self):
super(ComplexCardioid, self).__init__()
def forward(self, z: torch.Tensor) -> torch.Tensor:
out = 0.5 * (1. + torch.cos(z.angle())) * z
return out
class ComplexReLU(nn.Module):
"""
Complex-valued variants of the ReLU activation function
"""
def __init__(self, negative_slope=0., mode="real", bias_shape=None, scale=1.):
super(ComplexReLU, self).__init__()
# store parameters
self.mode = mode
if self.mode in ["modulus", "halfplane"]:
if bias_shape is not None:
self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
else:
self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
else:
self.bias = 0
self.negative_slope = negative_slope
self.act = nn.LeakyReLU(negative_slope = negative_slope)
def forward(self, z: torch.Tensor) -> torch.Tensor:
if self.mode == "cartesian":
zr = torch.view_as_real(z)
za = self.act(zr)
out = torch.view_as_complex(za)
elif self.mode == "modulus":
zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0)
elif self.mode == "cardioid":
out = 0.5 * (1. + torch.cos(z.angle())) * z
# elif self.mode == "halfplane":
# # bias is an angle parameter in this case
# modified_angle = torch.angle(z) - self.bias
# condition = torch.logical_and( (0. <= modified_angle), (modified_angle < torch.pi/2.) )
# out = torch.where(condition, z, self.negative_slope * z)
elif self.mode == "real":
zr = torch.view_as_real(z)
outr = zr.clone()
outr[..., 0] = self.act(zr[..., 0])
out = torch.view_as_complex(outr)
else:
raise NotImplementedError
return out
\ No newline at end of file
......
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:24.12-py3
# we need this for tests
RUN pip install parameterized
# we install this for the examples
RUN pip install wandb
# cartopy
RUN pip install cartopy
# h5py
RUN pip install h5py
# natten
RUN cd /opt && git clone https://github.com/SHI-Labs/NATTEN natten && \
cd natten && \
make WITH_CUDA=1 \
CUDA_ARCH="7.0;7.2;7.5;8.0;8.6;8.7;9.0" \
WORKERS=4
# install torch harmonics
COPY . /workspace/torch_harmonics
# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
......@@ -56,7 +56,7 @@ The SHT algorithm uses quadrature rules to compute the projection onto the assoc
torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].
torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators [1].
<div align="center">
<table border="0" cellspacing="0" cellpadding="0">
......@@ -169,9 +169,13 @@ $$
Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
### Discrete-continuous convolutions
### Discrete-continuous convolutions on the sphere
torch-harmonics now provides local discrete-continuous (DISCO) convolutions as outlined in [4] on the sphere.
torch-harmonics now provides local discrete-continuous (DISCO) convolutions as outlined in [4] on the sphere. These are used in local neural operators to generalize convolutions to structured and unstructured meshes on the sphere.
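Schematically, a DISCO convolution evaluates the continuous convolution of a signal $f$ with a localized filter $\kappa$ and approximates the integral by a quadrature rule on the grid points $y_j$ with weights $\omega_j$ (a sketch of the idea; see [4] for the precise, equivariant formulation):

$$
(\kappa \star f)(x) = \int_{S^2} \kappa(x, y) \, f(y) \, \mathrm{d}\mu(y) \approx \sum_j \kappa(x, y_j) \, f(y_j) \, \omega_j.
$$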
### Spherical (neighborhood) attention
torch-harmonics introduces spherical attention mechanisms which correctly generalize the attention mechanism to the sphere. The use of quadrature rules makes the resulting operations approximately equivariant in the discrete setting and exactly equivariant in the continuous limit. Moreover, neighborhood attention is generalized to the sphere by using the geodesic distance to determine the extent of the neighborhood.
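Schematically, spherical attention replaces the usual sum over tokens with an integral over the sphere, which is then discretized with the quadrature weights $\omega_j$ (a sketch; $q$, $k$, $v$ denote query, key and value signals and $d$ the channel dimension):

$$
y(x) = \frac{\int_{S^2} \exp\left( q(x)^\top k(x') / \sqrt{d} \right) v(x') \, \mathrm{d}\mu(x')}{\int_{S^2} \exp\left( q(x)^\top k(x') / \sqrt{d} \right) \mathrm{d}\mu(x')}, \qquad y(x_i) \approx \frac{\sum_j \exp\left( q_i^\top k_j / \sqrt{d} \right) v_j \, \omega_j}{\sum_j \exp\left( q_i^\top k_j / \sqrt{d} \right) \omega_j}.
$$

For neighborhood attention, the integration domain is restricted to a geodesic disc around the query point $x$.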
## Getting started
......@@ -208,6 +212,16 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
8. [Training Spherical Fourier Neural Operators (SFNO)](./notebooks/train_sfno.ipynb)
9. [Resampling signals on the sphere](./notebooks/resample_sphere.ipynb)
## Examples and reproducibility
The `examples` folder contains training scripts for three distinct tasks:
* [solution of the shallow water equations on the rotating sphere](./examples/shallow_water_equations/train.py)
* [depth estimation on the sphere](./examples/depth/train.py)
* [semantic segmentation on the sphere](./examples/segmentation/train.py)
Results from the papers can generally be reproduced by running `python train.py`. For some older results, the number of epochs and the learning rate may need to be adjusted by passing the corresponding command-line arguments.
## Remarks on automatic mixed precision (AMP) support
Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
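A minimal sketch of this pattern (here `encoder`, `decoder`, `n_theta`, `n_lambda` and `device` are placeholders for the surrounding model code):

```python
import torch
import torch_harmonics as th

sht = th.RealSHT(n_theta, n_lambda, grid="equiangular").to(device)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    x = encoder(x)  # runs in reduced precision
    # disable autocast locally and upcast before the harmonic transform
    with torch.autocast(device_type="cuda", enabled=False):
        coeffs = sht(x.float())
    x = decoder(coeffs)
```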
......@@ -237,7 +251,7 @@ Depending on the problem, it might be beneficial to upcast data to `float64` ins
## Contributors
[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ) , [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)
[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Max Rietmann](https://github.com/rietmann-nv), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ), [Andrea Paris](https://github.com/apaaris), [Alberto Carpentieri](https://github.com/albertocarpentieri), [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)
## Cite us
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -29,48 +29,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import torch_harmonics as harmonics
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x : x
# everything is awesome on GPUs
device = torch.device("cuda")
# create a batch with one sample and 21 channels
b, c, n_theta, n_lambda = 1, 21, 360, 720
# your layers to play with
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_equi = harmonics.RealSHT(n_theta, n_lambda, grid="equiangular").to(device)
inverse_transform_equi = harmonics.InverseRealSHT(n_theta, n_lambda, grid="equiangular").to(device)
signal_leggauss = inverse_transform(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))
signal_equi = inverse_transform_equi(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))
# let's check the layers
for num_iters in [1, 8, 64, 512]:
base = signal_leggauss
for iteration in tqdm(range(num_iters)):
base = inverse_transform(forward_transform(base))
print("relative l2 error accumulation on the legendre-gauss grid: ",
torch.mean(torch.norm(base-signal_leggauss, p='fro', dim=(-1,-2)) / torch.norm(signal_leggauss, p='fro', dim=(-1,-2)) ).item(),
"after", num_iters, "iterations")
# let's check the equiangular layers
for num_iters in [1, 8, 64, 512]:
base = signal_equi
for iteration in tqdm(range(num_iters)):
base = inverse_transform_equi(forward_transform_equi(base))
print("relative l2 error accumulation with interpolation onto equiangular grid: ",
torch.mean(torch.norm(base-signal_equi, p='fro', dim=(-1,-2)) / torch.norm(signal_equi, p='fro', dim=(-1,-2)) ).item(),
"after", num_iters, "iterations")
from .transformer import Transformer
from .segformer import Segformer
from .unet import UNet
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
from natten import NeighborhoodAttention2D as NeighborhoodAttention
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath
from functools import partial
class OverlapPatchMerging(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(481, 960),
in_channels=3,
out_channels=64,
kernel_shape=(3, 3),
bias=False,
):
super().__init__()
# conv
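# choose stride and padding so that the strided convolution maps in_shape onto out_shape:
# out_dim = floor((in_dim + 2*pad - kernel) / stride) + 1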
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2)
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_shape,
bias=bias,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
)
# layer norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.conv(x)
# permute
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
out = x.permute(0, 3, 1, 2)
return out
class MixFFN(nn.Module):
def __init__(
self,
shape,
inout_channels,
hidden_channels,
mlp_bias=True,
kernel_shape=(3, 3),
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
drop_path=0.0,
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm = nn.LayerNorm((inout_channels), eps=1e-05, elementwise_affine=True, bias=True)
if use_mlp:
# although the paper says MLP, it uses a single linear layer
self.mlp_in = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_in = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.conv = nn.Conv2d(inout_channels, inout_channels, kernel_size=kernel_shape, groups=inout_channels, bias=conv_bias, padding="same")
if use_mlp:
self.mlp_out = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_out = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.act = activation()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
# norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here
# because in the paper they only use depthwise conv,
# but without this activation it would just be a fused MM
# with the disco conv
x = self.mlp_in(x)
# conv path
x = self.act(self.conv(x))
# second linear
x = self.mlp_out(x)
return residual + self.drop_path(x)
class GlobalAttention(nn.Module):
"""
Global self-attention block over 2D inputs using MultiheadAttention.
Input shape: (B, H, W, C)
Output shape: (B, H, W, C); the residual skip is applied by the calling block.
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C
B, H, W, C = x.shape
# flatten spatial dims (reshape handles non-contiguous tensors after permute)
x_flat = x.reshape(B, H * W, C) # B, N, C
# self-attention
out, _ = self.attn(x_flat, x_flat, x_flat)
# reshape back
out = out.view(B, H, W, C)
return out
class AttentionWrapper(nn.Module):
def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
self.att = NeighborhoodAttention(
channels, kernel_size=kernel_shape, dilation=1, num_heads=heads, qk_scale=None, attn_drop=attention_drop_rate, proj_drop=0.0, qkv_bias=bias
)
elif attention_mode == "global":
self.att = GlobalAttention(channels, num_heads=heads, dropout=attention_drop_rate, bias=bias)
else:
raise ValueError(f"Unknown attention mode function {attention_mode}")
self.norm = None
if pre_norm:
self.norm = nn.LayerNorm((channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
x = self.att(x)
x = x.permute(0, 3, 1, 2)
return residual + self.drop_path(x)
class TransformerBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
mlp_hidden_channels,
nrep=1,
heads=1,
kernel_shape=(3, 3),
activation=nn.GELU,
att_drop_rate=0.0,
drop_path_rates=0.0,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=True
):
super().__init__()
# ensure odd
if attn_kernel_shape[0] % 2 == 0:
raise ValueError(f"Attention kernel shape {attn_kernel_shape} is even, use an odd kernel shape")
if attn_kernel_shape[1] % 2 == 0:
raise ValueError(f"Attention kernel shape {attn_kernel_shape} is even, use an odd kernel shape")
attn_kernel_shape = list(attn_kernel_shape)
orig_attn_kernel_shape = attn_kernel_shape.copy()
# ensure that attn kernel shape is smaller than in_shape in both dimensions
# if necessary fix kernel_shape to be 1 less (and odd) than in_shape
if attn_kernel_shape[0] >= out_shape[0]:
attn_kernel_shape[0] = out_shape[0] - 1
# ensure odd
if attn_kernel_shape[0] % 2 == 0:
attn_kernel_shape[0] -= 1
# make square if original was square
if orig_attn_kernel_shape[0] == orig_attn_kernel_shape[1]:
attn_kernel_shape[1] = attn_kernel_shape[0]
if attn_kernel_shape[1] >= out_shape[1]:
attn_kernel_shape[1] = out_shape[1] - 1
# ensure odd
if attn_kernel_shape[1] % 2 == 0:
attn_kernel_shape[1] -= 1
attn_kernel_shape = tuple(attn_kernel_shape)
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(drop_path_rates, float):
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rates, nrep)]
assert len(drop_path_rates) == nrep
self.fwd = [
OverlapPatchMerging(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
bias=False,
)
]
for i in range(nrep):
self.fwd.append(
AttentionWrapper(
channels=out_channels,
shape=out_shape,
heads=heads,
pre_norm=True,
attention_drop_rate=att_drop_rate,
drop_path=drop_path_rates[i],
attention_mode=attention_mode,
kernel_shape=attn_kernel_shape,
bias=bias
)
)
self.fwd.append(
MixFFN(
out_shape,
inout_channels=out_channels,
hidden_channels=mlp_hidden_channels,
mlp_bias=True,
kernel_shape=kernel_shape,
conv_bias=False,
activation=activation,
use_mlp=False,
drop_path=drop_path_rates[i],
)
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fwd(x)
# apply norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
class Upsampling(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
hidden_channels,
mlp_bias=True,
kernel_shape=(3, 3),
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
):
super().__init__()
self.out_shape = out_shape
if use_mlp:
self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = nn.functional.interpolate(self.mlp(x), size=self.out_shape, mode="bilinear")
return x
class Segformer(nn.Module):
"""
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape of the input image, by default (128, 256)
kernel_shape : tuple, optional
Kernel size of the convolutional layers, by default (3, 3)
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as heads
heads : List[int], optional
Number of heads for each block in the network, has to be the same length as embed_dims
depths : List[int], optional
Number of repetitions of attentions blocks and ffn mixers per layer. Has to be the same length as embed_dims and heads
activation_function : str, optional
Activation function to use, by default "gelu"
embedder_kernel_shape : int, optional
size of the encoder kernel
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
-----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
... out_chans=3,
... embed_dims=[64, 128, 256, 512],
... heads=[1, 2, 4, 8],
... depths=[3, 4, 6, 3],
... scale_factor=2,
... activation_function="gelu",
... kernel_shape=(3, 3),
... mlp_ratio=2.0,
... att_drop_rate=0.0,
... drop_path_rate=0.1,
... attention_mode="global",
... )
>>> model(torch.randn(1, 3, 128, 256)).shape
torch.Size([1, 3, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(3, 3),
mlp_ratio=2.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=True
):
super().__init__()
self.img_size = img_size
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.heads = heads
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.heads) == self.num_blocks
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
self.blocks = nn.ModuleList([])
out_shape = img_size
in_channels = in_chans
cur = 0
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.blocks.append(
TransformerBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
mlp_hidden_channels=int(mlp_ratio * out_channels),
nrep=self.depths[i],
heads=self.heads[i],
kernel_shape=kernel_shape,
activation=self.activation_function,
att_drop_rate=att_drop_rate,
drop_path_rates=dpr[cur : cur + self.depths[i]],
attention_mode=attention_mode,
attn_kernel_shape=attn_kernel_shape,
bias=bias
)
)
cur += self.depths[i]
out_shape = out_shape_new
in_channels = out_channels
self.upsamplers = nn.ModuleList([])
out_shape = img_size
for i in range(self.num_blocks):
in_shape = self.blocks[i].out_shape
self.upsamplers.append(
Upsampling(
in_shape=in_shape,
out_shape=out_shape,
in_channels=self.embed_dims[i],
out_channels=self.embed_dims[i],
hidden_channels=int(mlp_ratio * self.embed_dims[i]),
mlp_bias=True,
kernel_shape=kernel_shape,
conv_bias=False,
activation=nn.GELU,
)
)
segmentation_head_dim = sum(self.embed_dims)
self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for block in self.blocks:
feat = block(feat)
features.append(feat)
# perform upsample
upfeats = []
for feat, upsampler in zip(features, self.upsamplers):
upfeat = upsampler(feat)
upfeats.append(upfeat)
# perform concatenation
upfeats = torch.cat(upfeats, dim=1)
# final upsampling and prediction
out = self.segmentation_head(upfeats)
return out
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from natten import NeighborhoodAttention2D as NeighborhoodAttention
from functools import partial
class Encoder(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(480, 960),
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
groups=1,
bias=False,
):
super().__init__()
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2)
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x):
x = self.conv(x)
return x
class Decoder(nn.Module):
def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
super().__init__()
self.out_shape = out_shape
self.upsampling_method = upsampling_method
if upsampling_method == "conv":
self.upsample = nn.Sequential(
nn.Upsample(
size=out_shape,
mode="bilinear",
),
nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, padding="same", groups=groups),
)
elif upsampling_method == "pixel_shuffle":
# check if it is possible to use PixelShuffle
if out_shape[0] // in_shape[0] != out_shape[1] // in_shape[1]:
raise Exception(f"out_shape {out_shape} and in_shape {in_shape} are incompatible for shuffle decoding")
upsampling_factor = out_shape[0] // in_shape[0]
self.upsample = nn.Sequential(
nn.Conv2d(in_chans, out_chans * (upsampling_factor**2), kernel_size=1, bias=bias, padding=0, groups=groups), nn.PixelShuffle(upsampling_factor)
)
else:
raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x):
x = self.upsample(x)
return x
class GlobalAttention(nn.Module):
"""
Global self-attention block over 2D inputs using MultiheadAttention.
Input shape: (B, H, W, C)
Output shape: (B, H, W, C); the residual skip is applied by the calling block.
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C
B, H, W, C = x.shape
# flatten spatial dims
x_flat = x.reshape(B, H * W, C) # B, N, C
# self-attention
out, _ = self.attn(x_flat, x_flat, x_flat)
# reshape back
out = out.view(B, H, W, C)
return out
class AttentionBlock(nn.Module):
"""
Neighborhood attention block based on Natten.
"""
def __init__(
self,
in_shape=(480, 960),
out_shape=(480, 960),
chans=2,
num_heads=1,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="none",
use_mlp=True,
bias=True,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
):
super().__init__()
# normalisation layer
if norm_layer == "layer_norm":
self.norm0 = LayerNorm(in_channels=chans, eps=1e-6)
self.norm1 = LayerNorm(in_channels=chans, eps=1e-6)
elif norm_layer == "instance_norm":
self.norm0 = nn.InstanceNorm2d(num_features=chans, eps=1e-6, affine=True, track_running_stats=False)
self.norm1 = nn.InstanceNorm2d(num_features=chans, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm0 = nn.Identity()
self.norm1 = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# determine shape for neighborhood attention
if attention_mode == "neighborhood":
self.self_attn = NeighborhoodAttention(
chans,
kernel_size=attn_kernel_shape,
dilation=1,
num_heads=num_heads,
qkv_bias=bias,
qk_scale=None,
attn_drop=drop_rate,
proj_drop=drop_rate,
)
else:
self.self_attn = GlobalAttention(chans, num_heads=num_heads, dropout=drop_rate, bias=bias)
self.skip0 = nn.Identity()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
if use_mlp == True:
mlp_hidden_dim = int(chans * mlp_ratio)
self.mlp = MLP(
in_features=chans,
out_features=chans,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop_rate=drop_rate,
checkpointing=False,
gain=0.5,
)
self.skip1 = nn.Identity()
def forward(self, x):
residual = x
x = self.norm0(x)
x = x.permute(0, 2, 3, 1)
x = self.self_attn(x).permute(0, 3, 1, 2)
if hasattr(self, "skip0"):
x = x + self.skip0(residual)
residual = x
x = self.norm1(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, "skip1"):
x = x + self.skip1(residual)
return x
class Transformer(nn.Module):
"""
Parameters
----------
img_size : tuple of int
(latitude, longitude) size of the input tensor.
scale_factor : int
Ratio for down- and up-sampling between input and internal resolution.
in_chans : int
Number of channels in the input tensor.
out_chans : int
Number of channels in the output tensor.
embed_dim : int
Embedding dimension inside attention blocks.
num_layers : int
Number of attention blocks.
activation_function : str
"relu", "gelu", or "identity" specifying the activation.
encoder_kernel_shape : tuple of int
Kernel size for the encoder convolution.
num_heads : int
Number of heads in NeighborhoodAttention.
use_mlp : bool
If True, an MLP follows attention in each block.
mlp_ratio : float
Ratio of MLP hidden dim to input dim.
drop_rate : float
Dropout rate before positional embedding.
drop_path_rate : float
Stochastic depth rate across transformer blocks.
normalization_layer : str
"layer_norm", "instance_norm", or "none".
residual_prediction : bool
If True, add the input as a global skip connection.
pos_embed : str
"sequence", "spectral", "learnable lat", "learnable latlon", or "none".
bias : bool
Whether convolution and attention projections include bias.
attention_mode: str
"neighborhood" or "global"
upsampling_method: str
"conv" or "pixel_shuffle"
attn_kernel_shape : tuple of int
Kernel size of the neighborhood attention window, by default (7, 7).
Example
-------
>>> model = Transformer(
... img_size=(128, 256),
... scale_factor=2,
... in_chans=3,
... out_chans=3,
... embed_dim=256,
... num_layers=4,
... activation_function="gelu",
... encoder_kernel_shape=(3, 3),
... num_heads=1,
... use_mlp=True,
... mlp_ratio=2.0,
... drop_rate=0.0,
... drop_path_rate=0.0,
... normalization_layer="none",
... residual_prediction=False,
... pos_embed="spectral",
... bias=True,
... attention_mode="neighborhood",
... attn_kernel_shape=(7,7),
... upsampling_method="conv"
... )
>>> x = torch.randn(1, 3, 128, 256)
>>> print(model(x).shape)
torch.Size([1, 3, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid_internal="legendre-gauss",
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="gelu",
encoder_kernel_shape=(3, 3),
num_heads=1,
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
residual_prediction=False,
pos_embed="spectral",
bias=True,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
upsampling_method="conv",
):
super().__init__()
self.img_size = img_size
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = embed_dim
self.num_layers = num_layers
self.encoder_kernel_shape = encoder_kernel_shape
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
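# e.g. img_size=(721, 1440) with scale_factor=3 yields an internal resolution of (241, 480)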
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = Encoder(
in_shape=self.img_size,
out_shape=(self.h, self.w),
in_chans=self.in_chans,
out_chans=self.embed_dim,
kernel_shape=self.encoder_kernel_shape,
groups=1,
bias=False,
)
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
block = AttentionBlock(
in_shape=(self.h, self.w),
out_shape=(self.h, self.w),
chans=self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
bias=bias,
attention_mode=attention_mode,
attn_kernel_shape=attn_kernel_shape,
)
self.blocks.append(block)
# decoder
self.decoder = Decoder(
in_shape=(self.h, self.w),
out_shape=self.img_size,
in_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
groups=1,
bias=False,
upsampling_method=upsampling_method,
)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
from torch_harmonics.examples.models._layers import MLP, DropPath
from functools import partial
class DownsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
nrep=1,
kernel_shape=(3, 3),
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.,
drop_path_rate=0.,
drop_dense_rate=0.,
downsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.downsampling_mode = downsampling_mode
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.fwd =[]
for i in range(nrep):
# conv
self.fwd.append(
nn.Conv2d(
in_channels=(in_channels if i==0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same"
)
)
if drop_conv_rate > 0.:
self.fwd.append(
nn.Dropout2d(
p=drop_conv_rate
)
)
# batchnorm
self.fwd.append(
nn.BatchNorm2d(out_channels,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
)
# activation
self.fwd.append(
activation(),
)
if downsampling_mode == "conv":
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h
- in_shape[0]
+ kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w
- in_shape[1]
+ kernel_shape[1]) / 2)
self.downsample = nn.Conv2d(
in_channels=(in_channels if i==0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w)
)
else:
self.downsample = nn.Identity()
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
bias=True)
if drop_dense_rate >0.:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = self.fwd(x)
# add residual connection
x = residual + self.drop_path(x)
# downsample
x = self.downsample(x)
if self.downsampling_mode == "bilinear":
x = F.interpolate(x, size=self.out_shape, mode="bilinear")
return x
class UpsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
nrep=1,
kernel_shape=(3, 3),
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.,
drop_path_rate=0.,
drop_dense_rate=0.,
upsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.upsampling_mode = upsampling_mode
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
if in_shape != out_shape:
if upsampling_mode == "conv":
stride_h = out_shape[0] // in_shape[0]
stride_w = out_shape[1] // in_shape[1]
pad_h = math.ceil(((in_shape[0] - 1) * stride_h
- in_shape[0]
+ kernel_shape[0]) / 2)
pad_w = math.ceil(((in_shape[1] - 1) * stride_w
- in_shape[1]
+ kernel_shape[1]) / 2)
self.upsample = nn.Sequential(
nn.ConvTranspose2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_shape,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w)
),
nn.BatchNorm2d(out_channels,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True),
activation(),
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same")
)
self.fwd =[]
for i in range(nrep):
# conv
self.fwd.append(
nn.Conv2d(
in_channels=(in_channels if i == 0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same")
)
if drop_conv_rate > 0.:
self.fwd.append(
nn.Dropout2d(
p=drop_conv_rate
)
)
# batchnorm
self.fwd.append(
nn.BatchNorm2d(out_channels,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
)
# activation
self.fwd.append(
activation(),
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
bias=True)
if drop_dense_rate >0.:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = residual + self.drop_path(self.fwd(x))
# upsampling
if self.upsampling_mode=="bilinear":
x = F.interpolate(x, size=self.out_shape, mode="bilinear")
else:
x = self.upsample(x)
return x
class UNet(nn.Module):
"""
Spherical U-Net model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
num_classes : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as depths
depths : List[int], optional
Number of repetitions of conv blocks and ffn mixers per layer. Has to be the same length as embed_dims
activation_function : str, optional
Activation function to use, by default "relu"
embedder_kernel_shape : int, optional
size of the encoder kernel
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
-----------
>>> model = UNet(
... img_shape=(128, 256),
... scale_factor=2,
... in_chans=2,
... num_classes=2,
... embed_dims=[64, 128, 256, 512],)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_shape=(128, 256),
in_chans=3,
num_classes=3,
embed_dims=[64, 128, 256, 512],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="relu",
kernel_shape=(3, 3),
transform_skip=False,
drop_conv_rate=0.1,
drop_path_rate=0.1,
drop_dense_rate=0.5,
downsampling_mode="bilinear",
upsampling_mode="bilinear",
):
super().__init__()
self.img_shape = img_shape
self.in_chans = in_chans
self.num_classes = num_classes
self.embed_dims = embed_dims
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert(len(self.depths) == self.num_blocks)
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]
self.dblocks = nn.ModuleList([])
out_shape = img_shape
in_channels = in_chans
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.dblocks.append(
DownsamplingBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
nrep=self.depths[i],
kernel_shape=kernel_shape,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=dpr[i],
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
downsampling_mode=downsampling_mode,
)
)
out_shape = out_shape_new
in_channels = out_channels
self.ublocks = nn.ModuleList([])
for i in range(self.num_blocks-1, -1, -1):
in_shape = self.dblocks[i].out_shape
out_shape = self.dblocks[i].in_shape
in_channels = self.dblocks[i].out_channels
if i != self.num_blocks-1:
in_channels = 2 * in_channels
out_channels = self.dblocks[i].in_channels
if i==0:
out_channels = self.embed_dims[0]
self.ublocks.append(
UpsamplingBlock(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=0.,
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
upsampling_mode=upsampling_mode,
)
)
self.head = nn.Conv2d(self.embed_dims[0], self.num_classes, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for dblock in self.dblocks:
feat = dblock(feat)
features.append(feat)
# reverse list
features = features[::-1]
# perform upsample
ufeat = self.ublocks[0](features[0])
for feat, ublock in zip(features[1:], self.ublocks[1:]):
ufeat = ublock(torch.cat([feat, ufeat], dim=1))
# last layer
out = self.head(ufeat)
return out
if __name__ == "__main__":
model = UNet(
img_shape=(128, 256),
scale_factor=2,
in_chans=2,
embed_dims=[64, 128, 256],
depths=[2, 2, 2])
print(model)
print(model(torch.randn(1, 2, 128, 256)).shape)
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os, sys
import time
import argparse
from functools import partial
import torch
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import v2
import pandas as pd
import matplotlib.pyplot as plt
from torch_harmonics.examples import (
StanfordDepthDataset,
Stanford2D3DSDownloader,
compute_stats_s2,
)
from torch_harmonics.examples.losses import W11LossS2, L1LossS2, L2LossS2, NormalLossS2
from torch_harmonics.plotting import plot_sphere, imshow_sphere
# import baseline models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model_registry import get_baseline_models
# wandb logging
import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
weights_and_grads_fname = os.path.join(log_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname)
weights_dict = {k: v for k, v in model.named_parameters()}
grad_dict = {k: v.grad for k, v in model.named_parameters()}
store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
torch.save(store_dict, weights_and_grads_fname)
# validation routine: computes losses and metrics on the validation set and plots predictions
def validate_model(
model,
dataloader,
loss_fn,
metrics_fns,
path_root,
normalization_in=None,
normalization_out=None,
logging=True,
device=torch.device("cpu"),
):
model.eval()
num_examples = len(dataloader)
# make output
if logging and not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
losses = torch.zeros(num_examples, dtype=torch.float32, device=device)
metrics = {}
for metric in metrics_fns:
metrics[metric] = torch.zeros(num_examples, dtype=torch.float32, device=device)
glob_off = 0
if dist.is_initialized():
glob_off = num_examples * dist.get_rank()
with torch.no_grad():
for idx, (inp, tar) in enumerate(dataloader):
inpd = inp.to(device)
tar = tar.to(device)
mask = torch.where(tar == 0, 0.0, 1.0)
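# pixels where the target is zero are treated as missing and masked out of the loss and metrics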
if normalization_in is not None:
inpd = normalization_in(inpd)
if normalization_out is not None:
tar = normalization_out(tar)
prd = model(inpd)
losses[idx] = loss_fn(prd, tar.unsqueeze(-3), mask)
for metric in metrics_fns:
metric_buff = metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[idx] = metric_fn(prd, tar.unsqueeze(-3), mask)
tar = (tar * mask).squeeze()
prd = (prd * mask).squeeze()
# common color range across target and prediction
vmin = min(tar.min(), prd.min())
vmax = max(tar.max(), prd.max())
# do plotting
glob_idx = idx + glob_off
fig = plt.figure(figsize=(7.5, 6))
plot_sphere(prd.cpu(), fig=fig, vmax=vmax, vmin=vmin, cmap="plasma")
plt.savefig(os.path.join(path_root, "pred_" + str(glob_idx) + ".png"))
plt.close()
fig = plt.figure(figsize=(7.5, 6))
plot_sphere(tar.cpu(), fig=fig, vmax=vmax, vmin=vmin, cmap="plasma")
plt.savefig(os.path.join(path_root, "truth_" + str(glob_idx) + ".png"))
plt.close()
fig = plt.figure(figsize=(7.5, 6))
imshow_sphere(inp.cpu().squeeze(0).permute(1, 2, 0), fig=fig)
plt.savefig(os.path.join(path_root, "input_" + str(glob_idx) + ".png"))
plt.close()
return losses, metrics
# training function
def train_model(
model,
train_dataloader,
train_sampler,
test_dataloader,
test_sampler,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler=None,
normalization_in=None,
normalization_out=None,
augmentation=False,
nepochs=20,
amp_mode="none",
log_grads=0,
exp_dir=None,
logging=True,
device=torch.device("cpu"),
):
train_start = time.time()
# set AMP type
amp_dtype = torch.float32
if amp_mode == "fp16":
amp_dtype = torch.float16
elif amp_mode == "bf16":
amp_dtype = torch.bfloat16
# count iterations
iters = 0
for epoch in range(nepochs):
# time each epoch
epoch_start = time.time()
# do the training
accumulated_loss = torch.zeros(2, dtype=torch.float32, device=device)
model.train()
if dist.is_initialized():
train_sampler.set_epoch(epoch)
for inp, tar in train_dataloader:
inp = inp.to(device)
tar = tar.to(device)
mask = torch.where(tar == 0, 0.0, 1.0)
if normalization_in is not None:
inp = normalization_in(inp)
if normalization_out is not None:
tar = normalization_out(tar)
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(amp_mode != "none")):
prd = model(inp)
loss = loss_fn(prd, tar.unsqueeze(-3), mask)
optimizer.zero_grad(set_to_none=True)
gscaler.scale(loss).backward()
if log_grads and (iters % log_grads == 0) and (exp_dir is not None):
log_weights_and_grads(exp_dir, model, iters=iters)
gscaler.step(optimizer)
gscaler.update()
# accumulate loss
accumulated_loss[0] += loss.detach() * inp.size(0)
accumulated_loss[1] += inp.size(0)
iters += 1
if dist.is_initialized():
dist.all_reduce(accumulated_loss)
accumulated_loss = (accumulated_loss[0] / accumulated_loss[1]).item()
# perform validation
valid_loss = torch.zeros(2, dtype=torch.float32, device=device)
# prepare metrics buffer for accumulation of validation metrics
valid_metrics = {}
for metric in metrics_fns:
valid_metrics[metric] = torch.zeros(1, dtype=torch.float32, device=device)
model.eval()
if dist.is_initialized():
test_sampler.set_epoch(epoch)
with torch.no_grad():
for inp, tar in test_dataloader:
inp = inp.to(device)
tar = tar.to(device)
mask = torch.where(tar == 0, 0.0, 1.0)
if normalization_in is not None:
inp = normalization_in(inp)
if normalization_out is not None:
tar = normalization_out(tar)
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(amp_mode != "none")):
prd = model(inp)
loss = loss_fn(prd, tar.unsqueeze(-3), mask)
valid_loss[0] += loss * inp.size(0)
valid_loss[1] += inp.size(0)
for metric in metrics_fns:
metric_buff = valid_metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[0] += metric_fn(prd, tar.unsqueeze(-3), mask) * inp.size(0)
if dist.is_initialized():
dist.all_reduce(valid_loss)
for metric in metrics_fns:
dist.all_reduce(valid_metrics[metric])
num_valid_samples = valid_loss[1]
valid_loss = (valid_loss[0] / num_valid_samples).item()
for metric in valid_metrics:
valid_metrics[metric] = (valid_metrics[metric][0] / num_valid_samples).item()
if scheduler is not None:
scheduler.step(valid_loss)
epoch_time = time.time() - epoch_start
if logging:
print(f"--------------------------------------------------------------------------------")
print(f"Epoch {epoch} summary:")
print(f"time taken: {epoch_time:.2f}")
print(f"accumulated training loss: {accumulated_loss}")
print(f"relative validation loss: {valid_loss}")
for metric in valid_metrics:
print(f"{metric}: {valid_metrics[metric]}")
if wandb.run is not None:
current_lr = optimizer.param_groups[0]["lr"]
log_dict = {"loss": accumulated_loss, "validation loss": valid_loss, "learning rate": current_lr}
for metric in valid_metrics:
log_dict[metric] = valid_metrics[metric]
wandb.log(log_dict)
# wrapping up
train_time = time.time() - train_start
if logging:
print(f"--------------------------------------------------------------------------------")
print(f"done. Training took {train_time:.2f}.")
return valid_loss
def main(
root_path,
num_epochs=100,
batch_size=8,
learning_rate=1e-4,
train=True,
load_checkpoint=False,
amp_mode="none",
ddp=False,
enable_data_augmentation=False,
ignore_alpha_channel=True,
log_grads=0,
data_path="data",
data_downsampling_factor=16,
exclude_polar_fraction=0.15,
):
# initialize distributed
local_rank = 0
logging = True
if ddp:
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
local_rank = dist.get_rank() % torch.cuda.device_count()
logging = dist.get_rank() == 0
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# set device
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.cuda.set_device(device.index)
# create dataset directory if it doesn't exist
if logging:
os.makedirs(data_path, exist_ok=True)
# 2D3DS download & dataset initialization
downloader = Stanford2D3DSDownloader(base_url="https://cvg-data.inf.ethz.ch/2d3ds/no_xyz/", local_dir=str(data_path))
dataset_file = downloader.prepare_dataset(dataset_file=f"stanford_2d3ds_dataset_ds{data_downsampling_factor}.h5", downsampling_factor=data_downsampling_factor)
# make sure all ranks wait until the dataset has been prepared
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
# create the dataset and split it
if logging:
print(f"Initializing dataset...")
# make sure splitting is consistent across ranks
rng = torch.Generator().manual_seed(333)
split_ratios = [0.95, 0.025, 0.025]
dataset = StanfordDepthDataset(dataset_file=dataset_file, ignore_alpha_channel=ignore_alpha_channel, log_depth=False, exclude_polar_fraction=exclude_polar_fraction)
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(dataset, split_ratios, generator=rng)
# stats computation
means_in, stds_in, means_out, stds_out = compute_stats_s2(train_dataset.dataset, normalize_target=True)
train_dataset.dataset.reset()
if logging:
print(f"Computed stats:")
print(f"means_in={means_in}")
print(f"stds_in={stds_in}")
print(f"means_out={means_out}")
print(f"stds_out={stds_out}")
# split dataset if distributed
if dist.is_initialized():
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=True)
test_sampler = DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=True)
valid_sampler = DistributedSampler(valid_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=True)
else:
train_sampler = None
test_sampler = None
valid_sampler = None
# create the dataloaders
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True if train_sampler is None else False, sampler=train_sampler, num_workers=4, persistent_workers=True, pin_memory=True
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, sampler=test_sampler, num_workers=4, persistent_workers=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, sampler=valid_sampler, num_workers=0, persistent_workers=False, pin_memory=True)
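# validation uses batch size 1 so that per-sample losses can be recorded and individual plots generated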
# TODO: move augmentation into extra helper module
normalization_in = v2.Normalize(mean=means_in.tolist(), std=stds_in.tolist())
normalization_out = v2.Normalize(mean=means_out.tolist(), std=stds_out.tolist())
augmentation = enable_data_augmentation
in_channels = 3 if ignore_alpha_channel else 4
out_channels = 1
# print dataset info
img_size = dataset.input_shape[1:]
if logging:
print(f"Train dataset initialized with {len(train_dataset)} samples of resolution {img_size}")
print(f"Test dataset initialized with {len(test_dataset)} samples of resolution {img_size}")
print(f"Validation dataset initialized with {len(valid_dataset)} samples of resolution {img_size}")
# get baseline model registry
baseline_models = get_baseline_models(img_size=img_size, in_chans=in_channels, out_chans=out_channels)
# specify which models to train here
models = [
"transformer_sc2_layers4_e128",
"s2transformer_sc2_layers4_e128",
"ntransformer_sc2_layers4_e128",
"s2ntransformer_sc2_layers4_e128",
"segformer_sc2_layers4_e128",
"s2segformer_sc2_layers4_e128",
"nsegformer_sc2_layers4_e128",
"s2nsegformer_sc2_layers4_e128",
"sfno_sc2_layers4_e32",
"lsno_sc2_layers4_e32",
]
models = {k: baseline_models[k] for k in models}
# initialize the spherical Sobolev W1,1 and L1 loss functions
loss_w11 = W11LossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device)
loss_l1 = L1LossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device)
loss_fn = lambda prd, tar, mask: 0.1 * loss_w11(prd, tar, mask) + loss_l1(prd, tar, mask)
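# total loss: spherical L1 plus a 0.1-weighted Sobolev W1,1 term, which also penalizes errors in the first derivatives of the predicted depth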
# metrics
metrics = {}
metrics_fns = {
"L2 error": L2LossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device),
"L1 error": L1LossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device),
"W11 error": W11LossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device),
"Normals error": NormalLossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device),
}
# iterate over models and train each model
for model_name, model_handle in models.items():
model = model_handle().to(device)
if logging:
print(model)
if dist.is_initialized():
model = DDP(model, device_ids=[device.index])
metrics[model_name] = {}
num_params = count_parameters(model)
if logging:
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params
exp_dir = os.path.join(root_path, model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)
if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))
# run the training
if train:
if logging:
run = wandb.init(project="depth estimation 2d3ds", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
else:
run = None
# optimizer:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01, foreach=torch.cuda.is_available())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
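# ReduceLROnPlateau is stepped with the validation loss at the end of each epoch in train_model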
gscaler = torch.GradScaler("cuda", enabled=(amp_mode == "fp16"))
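# loss scaling is only required for fp16 autocast; bf16 and fp32 have sufficient dynamic range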
start_time = time.time()
if logging:
print(f"Training {model_name}")
train_model(
model,
train_dataloader,
train_sampler,
test_dataloader,
test_sampler,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler,
normalization_in=normalization_in,
normalization_out=normalization_out,
augmentation=augmentation,
nepochs=num_epochs,
amp_mode=amp_mode,
log_grads=log_grads,
exp_dir=exp_dir,
logging=logging,
device=device,
)
training_time = time.time() - start_time
if logging:
run.finish()
torch.save(model.state_dict(), os.path.join(exp_dir, "checkpoint.pt"))
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
with torch.inference_mode():
# run the validation
losses, metric_results = validate_model(
model,
valid_dataloader,
loss_fn,
metrics_fns,
os.path.join(exp_dir, "figures"),
normalization_in=normalization_in,
normalization_out=normalization_out,
logging=logging,
device=device,
)
# gather losses and metrics into a single tensor
if dist.is_initialized():
losses_dist = torch.zeros(world_size * losses.shape[0], dtype=losses.dtype, device=device)
dist.all_gather_into_tensor(losses_dist, losses)
losses = losses_dist
for metric_name, metric in metric_results.items():
metric_dist = torch.zeros(world_size * metric.shape[0], dtype=metric.dtype, device=device)
dist.all_gather_into_tensor(metric_dist, metric)
metric_results[metric_name] = metric_dist
# compute statistics
metrics[model_name]["loss mean"] = torch.mean(losses).item()
metrics[model_name]["loss std"] = torch.std(losses).item()
for metric in metric_results:
metrics[model_name][metric + " mean"] = torch.mean(metric_results[metric]).item()
metrics[model_name][metric + " std"] = torch.std(metric_results[metric]).item()
if train:
metrics[model_name]["training_time"] = training_time
if logging:
df = pd.DataFrame(metrics)
if not os.path.isdir(os.path.join(root_path, "output_data")):
os.makedirs(os.path.join(root_path, "output_data"), exist_ok=True)
df.to_pickle(os.path.join(root_path, "output_data", "metrics.pkl"))
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.set_start_method("forkserver", force=True)
wandb.login()
parser = argparse.ArgumentParser()
parser.add_argument(
"--output_path", default=os.path.join(os.path.dirname(__file__), "checkpoints"), type=str, help="Override the path where checkpoints and run information are stored"
)
parser.add_argument(
"--data_path",
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "2D3DS"),
type=str,
help="Directory to where the dataset is stored. If the dataset is not found in that location, it will be downloaded automatically.",
)
parser.add_argument("--num_epochs", default=100, type=int, help="Switch for overriding batch size in the configuration file.")
parser.add_argument("--batch_size", default=8, type=int, help="Switch for overriding batch size in the configuration file.")
parser.add_argument("--data_downsampling_factor", default=16, type=int, help="Switch for overriding the downsampling factor of the data.")
parser.add_argument("--learning_rate", default=1e-3, type=float, help="Switch to override learning rate.")
parser.add_argument("--resume", action="store_true", help="Reload checkpoints.")
parser.add_argument("--amp_mode", default="none", type=str, choices=["none", "bf16", "fp16"], help="Switch to enable AMP.")
parser.add_argument("--enable_ddp", action="store_true", help="Switch to enable distributed data parallel.")
parser.add_argument("--enable_data_augmentation", action="store_true", help="Switch to enable data augmentation.")
args = parser.parse_args()
main(
root_path=args.output_path,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
train=args.num_epochs > 0,
load_checkpoint=args.resume,
amp_mode=args.amp_mode,
ddp=args.enable_ddp,
enable_data_augmentation=args.enable_data_augmentation,
ignore_alpha_channel=True,
log_grads=0,
data_path=args.data_path,
data_downsampling_factor=args.data_downsampling_factor,
)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os, sys
from functools import partial
# import baseline models
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
# prepare dicts containing models and corresponding metrics
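# naming convention used below: "s2" prefix = spherical variant, leading "n" = neighborhood attention,
# "sc2" = scale factor 2, "layers4" = four layers / four-stage depth config, "eN" = embedding dimension N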
model_registry = dict(
sfno_sc2_layers4_e32 = partial(
SphericalFourierNeuralOperator,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
residual_prediction=residual_prediction,
use_mlp=True,
normalization_layer="instance_norm",
),
lsno_sc2_layers4_e32 = partial(
LocalSphericalNeuralOperator,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
residual_prediction=residual_prediction,
use_mlp=True,
normalization_layer="instance_norm",
kernel_shape=(5, 4),
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
upsample_sht=False,
),
s2unet_sc2_layers4_e128 = partial(
SphericalUNet,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=0.1,
drop_conv_rate=0.2,
drop_dense_rate=0.5,
transform_skip=False,
upsampling_mode="conv",
downsampling_mode="conv",
),
s2transformer_sc2_layers4_e128 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="global",
bias=False
),
s2transformer_sc2_layers4_e256 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="global",
bias=False
),
s2ntransformer_sc2_layers4_e128 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="neighborhood",
bias=False
),
s2ntransformer_sc2_layers4_e256 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="neighborhood",
bias=False
),
transformer_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="global",
upsampling_method="conv",
bias=False
),
transformer_sc2_layers4_e256 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="global",
upsampling_method="conv",
bias=False
),
ntransformer_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
ntransformer_sc2_layers4_e256 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
s2segformer_sc2_layers4_e128 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
s2segformer_sc2_layers4_e256 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
s2nsegformer_sc2_layers4_e128 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
bias=False
),
s2nsegformer_sc2_layers4_e256 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
bias=False
),
segformer_sc2_layers4_e128 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
segformer_sc2_layers4_e256 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
nsegformer_sc2_layers4_e128 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
nsegformer_sc2_layers4_e256 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
vit_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="learnable latlon",
use_mlp=True,
normalization_layer="layer_norm",
encoder_kernel_shape=(2, 2),
attention_mode="global",
upsampling_method="pixel_shuffle",
bias=False
),
)
return model_registry
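# minimal usage sketch (illustrative; mirrors how the training scripts consume this registry):
#   registry = get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=1)
#   model = registry["s2transformer_sc2_layers4_e128"]()  # entries are partials; call to instantiate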
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os, sys
import random
import time
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch_harmonics.examples import StanfordSegmentationDataset, Stanford2D3DSDownloader, StanfordDatasetSubset, compute_stats_s2
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.losses import DiceLossS2, CrossEntropyLossS2, FocalLossS2
from torch_harmonics.examples.metrics import IntersectionOverUnionS2, AccuracyS2
from torch_harmonics.plotting import plot_sphere, imshow_sphere
from torchvision.transforms import v2
# import baseline models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model_registry import get_baseline_models
# wandb logging
import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
weights_and_grads_fname = os.path.join(log_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname)
weights_dict = {k: v for k, v in model.named_parameters()}
grad_dict = {k: v.grad for k, v in model.named_parameters()}
store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
torch.save(store_dict, weights_and_grads_fname)
# evaluates the model on the validation set and saves prediction, target, and input plots
def validate_model(model, dataloader, loss_fn, metrics_fns, path_root, normalization=None, logging=True, device=torch.device("cpu")):
model.eval()
num_examples = len(dataloader)
# make output
if logging and not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
# accumulation buffers for metrics and losses
losses = torch.zeros(num_examples, dtype=torch.float32, device=device)
metrics = {}
for metric in metrics_fns:
metrics[metric] = torch.zeros(num_examples, dtype=torch.float32, device=device)
glob_off = 0
if dist.is_initialized():
glob_off = num_examples * dist.get_rank()
with torch.no_grad():
for idx, (inp, tar) in enumerate(dataloader):
(_, _, idx_file) = dataloader.dataset[idx]
inpd = inp.to(device)
tar = tar.to(device)
if normalization is not None:
inpd = normalization(inpd)
prd = model(inpd)
num_classes = prd.shape[-3]
losses[idx] = loss_fn(prd, tar)
for metric in metrics_fns:
metric_buff = metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[idx] = metric_fn(prd, tar)
prd = nn.functional.softmax(prd, dim=-3)
prd = torch.argmax(prd, dim=-3).squeeze(0)
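# predictions are now per-pixel class indices; they are divided by num_classes below to map them into [0, 1] for the colormap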
# do plotting
glob_idx = idx + glob_off
fig = plt.figure(figsize=(7.5, 6))
plot_sphere(prd.cpu() / num_classes, fig=fig, vmax=1.0, vmin=0.0, cmap="rainbow")
plt.savefig(os.path.join(path_root, "pred_" + str(glob_idx) + ".png"))
plt.close()
fig = plt.figure(figsize=(7.5, 6))
plot_sphere(tar.cpu().squeeze(0) / num_classes, fig=fig, vmax=1.0, vmin=0.0, cmap="rainbow")
plt.savefig(os.path.join(path_root, "truth_" + str(glob_idx) + ".png"))
plt.close()
fig = plt.figure(figsize=(7.5, 6))
imshow_sphere(inp.cpu().squeeze(0).permute(1, 2, 0), fig=fig)
plt.savefig(os.path.join(path_root, "input_" + str(glob_idx) + ".png"))
plt.close()
return losses, metrics
# training function
def train_model(
model,
train_dataloader,
train_sampler,
test_dataloader,
test_sampler,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler=None,
max_grad_norm=0.0,
normalization=None,
augmentation=None,
nepochs=20,
amp_mode="none",
log_grads=0,
exp_dir=None,
logging=True,
device=torch.device("cpu"),
):
train_start = time.time()
# set AMP type
amp_dtype = torch.float32
if amp_mode == "fp16":
amp_dtype = torch.float16
elif amp_mode == "bf16":
amp_dtype = torch.bfloat16
# count iterations
iters = 0
for epoch in range(nepochs):
# time each epoch
epoch_start = time.time()
# do the training
accumulated_loss = torch.zeros(2, dtype=torch.float32, device=device)
model.train()
if dist.is_initialized():
train_sampler.set_epoch(epoch)
for inp, tar in train_dataloader:
inp = inp.to(device)
tar = tar.to(device)
if normalization is not None:
inp = normalization(inp)
if augmentation is not None:
inp = augmentation(inp)
# flip randomly horizontally
if random.random() < 0.5:
inp = torch.flip(inp, dims=(-1,))
tar = torch.flip(tar, dims=(-1,))
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(amp_mode != "none")):
prd = model(inp)
loss = loss_fn(prd, tar)
optimizer.zero_grad(set_to_none=True)
gscaler.scale(loss).backward()
if log_grads and (iters % log_grads == 0) and (exp_dir is not None):
log_weights_and_grads(exp_dir, model, iters=iters)
if max_grad_norm > 0.0:
# unscale the gradients first so that clipping operates on true (unscaled) gradient norms
gscaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
gscaler.step(optimizer)
gscaler.update()
# accumulate loss
accumulated_loss[0] += loss.detach() * inp.size(0)
accumulated_loss[1] += inp.size(0)
iters += 1
if dist.is_initialized():
dist.all_reduce(accumulated_loss)
accumulated_loss = (accumulated_loss[0] / accumulated_loss[1]).item()
# perform validation
valid_loss = torch.zeros(2, dtype=torch.float32, device=device)
valid_metrics = {}
for metric in metrics_fns:
valid_metrics[metric] = torch.zeros(2, dtype=torch.float32, device=device)
model.eval()
if dist.is_initialized():
test_sampler.set_epoch(epoch)
with torch.no_grad():
for inp, tar in test_dataloader:
inp = inp.to(device)
tar = tar.to(device)
if normalization is not None:
inp = normalization(inp)
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(amp_mode != "none")):
prd = model(inp)
loss = loss_fn(prd, tar)
valid_loss[0] += loss * inp.size(0)
valid_loss[1] += inp.size(0)
for metric in metrics_fns:
metric_buff = valid_metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[0] += metric_fn(prd, tar) * inp.size(0)
metric_buff[1] += inp.size(0)
if dist.is_initialized():
dist.all_reduce(valid_loss)
for metric in metrics_fns:
dist.all_reduce(valid_metrics[metric])
valid_loss = (valid_loss[0] / valid_loss[1]).item()
for metric in valid_metrics:
valid_metrics[metric] = (valid_metrics[metric][0] / valid_metrics[metric][1]).item()
if scheduler is not None:
scheduler.step()
epoch_time = time.time() - epoch_start
if logging:
print(f"--------------------------------------------------------------------------------")
print(f"Epoch {epoch} summary:")
print(f"time taken: {epoch_time:.2f}")
print(f"accumulated training loss: {accumulated_loss}")
print(f"relative validation loss: {valid_loss}")
for metric in valid_metrics:
print(f"{metric}: {valid_metrics[metric]}")
if wandb.run is not None:
current_lr = optimizer.param_groups[0]["lr"]
log_dict = {"loss": accumulated_loss, "validation loss": valid_loss, "learning rate": current_lr}
for metric in valid_metrics:
log_dict[metric] = valid_metrics[metric]
wandb.log(log_dict)
# wrapping up
train_time = time.time() - train_start
if logging:
print(f"--------------------------------------------------------------------------------")
print(f"done. Training took {train_time:.2f}.")
return valid_loss
def main(
models,
root_path,
num_epochs=100,
batch_size=8,
learning_rate=1e-4,
label_smoothing=0.0,
max_grad_norm=0.0,
train=True,
load_checkpoint=False,
amp_mode="none",
ddp=False,
enable_data_augmentation=False,
ignore_alpha_channel=True,
log_grads=0,
data_path="data",
data_downsampling_factor=16,
):
# initialize distributed
local_rank = 0
logging = True
if ddp:
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
local_rank = dist.get_rank() % torch.cuda.device_count()
logging = dist.get_rank() == 0
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# set device
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.cuda.set_device(device.index)
# create dataset directory if it doesn't exist
if logging:
os.makedirs(data_path, exist_ok=True)
# 2D3DS download & dataset initialization
downloader = Stanford2D3DSDownloader(base_url="https://cvg-data.inf.ethz.ch/2d3ds/no_xyz/", local_dir=str(data_path))
dataset_file = downloader.prepare_dataset(dataset_file=f"stanford_2d3ds_dataset_ds{data_downsampling_factor}.h5", downsampling_factor=data_downsampling_factor)
# make sure all ranks wait until the dataset has been prepared
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
# create the dataset and split it
if logging:
print(f"Initializing dataset...")
# make sure splitting is consistent across ranks
rng = torch.Generator().manual_seed(333)
split_ratios = [0.95, 0.025, 0.025]
dataset = StanfordSegmentationDataset(dataset_file=dataset_file, ignore_alpha_channel=ignore_alpha_channel)
# Create custom subsets
train_indices, test_indices, valid_indices = torch.utils.data.random_split(range(len(dataset)), split_ratios, generator=rng)
train_dataset = StanfordDatasetSubset(dataset, train_indices)
test_dataset = StanfordDatasetSubset(dataset, test_indices)
valid_dataset = StanfordDatasetSubset(dataset, valid_indices, return_index=True)
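# return_index=True lets the validation loop trace each prediction back to its original sample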
# compute stats on the train dataset
means, stds = compute_stats_s2(train_dataset)
train_dataset.dataset.reset()
if logging:
print(f"Computed stats: means={means}, stds={stds}")
# split dataset if distributed
if dist.is_initialized():
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=True)
test_sampler = DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=True)
valid_sampler = DistributedSampler(valid_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=True)
else:
train_sampler = None
test_sampler = None
valid_sampler = None
# create the dataloaders
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True if train_sampler is None else False, sampler=train_sampler, num_workers=4, persistent_workers=True, pin_memory=True
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, sampler=test_sampler, num_workers=4, persistent_workers=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, sampler=valid_sampler, num_workers=0, persistent_workers=False, pin_memory=True)
# TODO: move augmentation into extra helper module
normalization = v2.Normalize(mean=means.tolist(), std=stds.tolist())
if enable_data_augmentation:
if not ignore_alpha_channel:
raise NotImplementedError("You can only use data augmentation with RGB images, RGBA is not supported.")
if logging:
print("Using data augmentation")
# photometric data augmentation
augmentation = v2.Compose(
[
v2.RandomAutocontrast(p=0.5),
v2.GaussianNoise(mean=0.0, sigma=0.1, clip=True),
v2.ColorJitter(),
]
)
else:
augmentation = None
in_channels = 3 if ignore_alpha_channel else 4
# print dataset info
img_size = dataset.input_shape[1:]
class_histogram = torch.from_numpy(dataset.class_histogram)
# various class weights were tried, such as inverse class frequency
# using no class weights seems to work best
class_weights = None
# make sure there is no nan
if (class_weights is not None) and torch.isnan(class_weights).any():
raise ValueError("The class weights contain NaN.")
if logging:
print(f"Train dataset initialized with {len(train_dataset)} samples of resolution {img_size}")
print(f"Test dataset initialized with {len(test_dataset)} samples of resolution {img_size}")
print(f"Validation dataset initialized with {len(valid_dataset)} samples of resolution {img_size}")
# get baseline model registry
baseline_models = get_baseline_models(img_size=img_size, in_chans=in_channels, out_chans=dataset.num_classes, drop_path_rate=0.1)
# specify which models to train here
if models is None:
models = [
"s2segformer_sc2_layers4_e128",
"s2segformer_sc2_layers4_e256",
"segformer_sc2_layers4_e128",
"segformer_sc2_layers4_e256",
"s2nsegformer_sc2_layers4_e128",
"s2nsegformer_sc2_layers4_e256",
"nsegformer_sc2_layers4_e128",
"nsegformer_sc2_layers4_e256",
"s2transformer_sc2_layers4_e128",
"s2transformer_sc2_layers4_e256",
"s2ntransformer_sc2_layers4_e128",
"s2ntransformer_sc2_layers4_e256",
"transformer_sc2_layers4_e128",
"transformer_sc2_layers4_e256",
"ntransformer_sc2_layers4_e128",
"ntransformer_sc2_layers4_e256",
"vit_sc2_layers4_e128",
"sfno_sc2_layers4_e32",
"lsno_sc2_layers4_e32",
]
elif isinstance(models, str):
models = [models]
models = {k: baseline_models[k] for k in models}
if len(models) == 0:
raise ValueError("No models selected")
# create the loss object
loss_fn = CrossEntropyLossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular", weight=class_weights, smooth=label_smoothing).to(device=device)
# loss_fn = DiceLossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular", weight=class_weights, smooth=label_smoothing).to(device=device)
# loss_fn = FocalLossS2(nlat=img_size[0], nlon=img_size[1], grid="equiangular").to(device=device)
# metrics
metrics = {}
metrics_fns = {
"mean IoU": IntersectionOverUnionS2(
nlat=img_size[0],
nlon=img_size[1],
grid="equiangular",
weight=class_weights,
).to(device=device),
"mean Accuracy": AccuracyS2(
nlat=img_size[0],
nlon=img_size[1],
grid="equiangular",
weight=class_weights,
).to(device=device),
}
# iterate over models and train each model
for model_name, model_handle in models.items():
model = model_handle().to(device)
if logging:
print(model)
if dist.is_initialized():
model = DDP(model, device_ids=[device.index])
metrics[model_name] = {}
num_params = count_parameters(model)
if logging:
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params
exp_dir = os.path.join(root_path, model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)
if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))
# run the training
if train:
if logging:
run = wandb.init(project="spherical segmentation 2d3ds", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
else:
run = None
# optimizer:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1, foreach=torch.cuda.is_available())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
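# cosine annealing is stepped once per epoch in train_model and decays the learning rate towards eta_min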
gscaler = torch.GradScaler("cuda", enabled=(amp_mode == "fp16"))
start_time = time.time()
if logging:
print(f"Training {model_name} with config {model_handle}")
train_model(
model,
train_dataloader,
train_sampler,
test_dataloader,
test_sampler,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler,
max_grad_norm=max_grad_norm,
normalization=normalization,
augmentation=augmentation,
nepochs=num_epochs,
amp_mode=amp_mode,
log_grads=log_grads,
exp_dir=exp_dir,
logging=logging,
device=device,
)
training_time = time.time() - start_time
if logging:
run.finish()
torch.save(model.state_dict(), os.path.join(exp_dir, "checkpoint.pt"))
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
with torch.inference_mode():
# run the validation
losses, metric_results = validate_model(
model, valid_dataloader, loss_fn, metrics_fns, os.path.join(exp_dir, "figures"), normalization=normalization, logging=logging, device=device
)
# gather losses and metrics into a single tensor
if dist.is_initialized():
losses_dist = torch.zeros(world_size * losses.shape[0], dtype=losses.dtype, device=device)
dist.all_gather_into_tensor(losses_dist, losses)
losses = losses_dist
for metric_name, metric in metric_results.items():
metric_dist = torch.zeros(world_size * metric.shape[0], dtype=metric.dtype, device=device)
dist.all_gather_into_tensor(metric_dist, metric)
metric_results[metric_name] = metric_dist
# compute statistics
metrics[model_name]["loss mean"] = torch.mean(losses).item()
metrics[model_name]["loss std"] = torch.std(losses).item()
for metric in metric_results:
metrics[model_name][metric + " mean"] = torch.mean(metric_results[metric]).item()
metrics[model_name][metric + " std"] = torch.std(metric_results[metric]).item()
if train:
metrics[model_name]["training_time"] = training_time
if logging:
df = pd.DataFrame(metrics)
if not os.path.isdir(os.path.join(exp_dir, "output_data")):
os.makedirs(os.path.join(exp_dir, "output_data"), exist_ok=True)
df.to_pickle(os.path.join(exp_dir, "output_data", "metrics.pkl"))
if dist.is_initialized():
dist.barrier(device_ids=[device.index])
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.set_start_method("forkserver", force=True)
wandb.login()
parser = argparse.ArgumentParser()
parser.add_argument(
"--output_path", default=os.path.join(os.path.dirname(__file__), "checkpoints"), type=str, help="Override the path where checkpoints and run information are stored"
)
parser.add_argument(
"--data_path",
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "2D3DS"),
type=str,
help="Directory to where the dataset is stored. If the dataset is not found in that location, it will be downloaded automatically.",
)
parser.add_argument("--models", default=None, type=str, nargs='+', help="Provide a list of models to run")
parser.add_argument("--num_epochs", default=200, type=int, help="Switch for overriding batch size in the configuration file.")
parser.add_argument("--batch_size", default=8, type=int, help="Switch for overriding batch size in the configuration file.")
parser.add_argument("--data_downsampling_factor", default=16, type=int, help="Switch for overriding the downsampling factor of the data.")
parser.add_argument("--learning_rate", default=5e-4, type=float, help="Switch to override learning rate.")
parser.add_argument("--max_grad_norm", default=4.0, type=float, help="Switch to override max grad norm. A value > 0 activates gradient clipping.")
parser.add_argument("--label_smoothing_factor", default=0.0, type=float, help="Label smoothing factor [0, 1].")
parser.add_argument("--resume", action="store_true", help="Reload checkpoints.")
parser.add_argument("--amp_mode", default="none", type=str, choices=["none", "bf16", "fp16"], help="Switch to enable AMP.")
parser.add_argument("--enable_ddp", action="store_true", help="Switch to enable distributed data parallel.")
parser.add_argument("--enable_data_augmentation", action="store_true", help="Switch to enable data augmentation.")
args = parser.parse_args()
main(
models=args.models,
root_path=args.output_path,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
label_smoothing=args.label_smoothing_factor,
max_grad_norm=args.max_grad_norm,
train=args.num_epochs > 0,
load_checkpoint=args.resume,
amp_mode=args.amp_mode,
ddp=args.enable_ddp,
enable_data_augmentation=args.enable_data_augmentation,
ignore_alpha_channel=True,
log_grads=0,
data_path=args.data_path,
data_downsampling_factor=args.data_downsampling_factor,
)
......@@ -29,11 +29,12 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import os, sys
import time
import argparse
from functools import partial
from tqdm import tqdm
from functools import partial
import torch
import torch.nn as nn
......@@ -45,117 +46,57 @@ import pandas as pd
import matplotlib.pyplot as plt
from torch_harmonics.examples import PdeDataset
from torch_harmonics.examples.losses import L1LossS2, SquaredL2LossS2, L2LossS2, W11LossS2
from torch_harmonics import RealSHT
from torch_harmonics.plotting import plot_sphere
# wandb logging
import wandb
def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
if relative:
loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1, -2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0] ** 2 + tar_coeffs[..., 1] ** 2
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1, -2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
coeffs = spectral_weights * coeffs
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1, -2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0] ** 2 + tar_coeffs[..., 1] ** 2
tar_coeffs = spectral_weights * tar_coeffs
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1, -2))
loss = loss / tar_norm2
if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
return loss
# import baseline models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model_registry import get_baseline_models
def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
h1_coeffs = spectral_weights * coeffs
h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
h1_loss = torch.sum(h1_norm2, dim=(-1, -2))
l2_loss = torch.sum(l2_norm2, dim=(-1, -2))
# wandb logging
try:
import wandb
except ImportError:
wandb = None
# strictly speaking this is not exactly h1 loss
if not squared:
loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
else:
loss = h1_loss + l2_loss
if relative:
raise NotImplementedError("Relative H1 loss not implemented")
# helper routine for counting the number of parameters in a model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
loss = loss.mean()
return loss
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes
"""
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname)
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
# compute the weighting factor first
fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
# weight = weight.reshape(*weight.shape, 1, 1)
weights_dict = {k: v for k, v in model.named_parameters()}
grad_dict = {k: v.grad for k, v in model.named_parameters()}
loss = weight * solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
if relative:
loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
loss = torch.mean(loss)
return loss
store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
torch.save(store_dict, weights_and_grads_fname)
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=50):
def autoregressive_inference(
model,
dataset,
loss_fn,
metrics_fns,
path_root,
nsteps,
autoreg_steps=10,
nskip=1,
plot_channel=0,
nics=50,
device=torch.device("cpu"),
):
model.eval()
......@@ -163,9 +104,13 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
if not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)
losses = np.zeros(nics)
fno_times = np.zeros(nics)
nwp_times = np.zeros(nics)
# accumulation buffers for losses, metrics and runtimes
losses = torch.zeros(nics, dtype=torch.float32, device=device)
metrics = {}
for metric in metrics_fns:
metrics[metric] = torch.zeros(nics, dtype=torch.float32, device=device)
model_times = torch.zeros(nics, dtype=torch.float32, device=device)
solver_times = torch.zeros(nics, dtype=torch.float32, device=device)
# accumulation buffers for the power spectrum
prd_mean_coeffs = []
......@@ -184,6 +129,16 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
ref_coeffs = [prd_coeffs[0].clone()]
# plot the initial condition
if iic == nics - 1 and nskip > 0 and i % nskip == 0:
# do plotting
fig = plt.figure(figsize=(6, 6))
plot_sphere(prd[0, plot_channel].cpu(), fig, vmax=4, vmin=-4, central_latitude=30, gridlines=True, projection="orthographic")
fig.tight_layout()
plt.savefig(os.path.join(path_root, "truth_" + str(0) + ".png"))
plt.close()
# ML model
start_time = time.time()
for i in range(1, autoreg_steps + 1):
......@@ -195,12 +150,13 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
if iic == nics - 1 and nskip > 0 and i % nskip == 0:
# do plotting
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'pred_'+str(i//nskip)+'.png'))
fig = plt.figure(figsize=(6, 6))
plot_sphere(prd[0, plot_channel].cpu(), fig, vmax=4, vmin=-4, central_latitude=30, gridlines=True, projection="orthographic")
fig.tight_layout()
plt.savefig(os.path.join(path_root, "pred_" + str(i // nskip) + ".png"))
plt.close()
fno_times[iic] = time.time() - start_time
model_times[iic] = time.time() - start_time
# classical model
start_time = time.time()
......@@ -213,21 +169,26 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
if iic == nics - 1 and i % nskip == 0 and nskip > 0:
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'truth_'+str(i//nskip)+'.png'))
fig = plt.figure(figsize=(6, 6))
plot_sphere(ref[plot_channel].cpu(), fig, vmax=4, vmin=-4, central_latitude=30, gridlines=True, projection="orthographic")
fig.tight_layout()
plt.savefig(os.path.join(path_root, "truth_" + str(i // nskip) + ".png"))
plt.close()
nwp_times[iic] = time.time() - start_time
solver_times[iic] = time.time() - start_time
# compute power spectrum and add it to the buffers
prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))
# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref = dataset.solver.spec2grid(uspec)
prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
# ref = dataset.solver.spec2grid(uspec)
losses[iic] = loss_fn(prd, ref)
# prd = prd * torch.sqrt(inp_var) + inp_mean
for metric in metrics_fns:
metric_buff = metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[iic] = metric_fn(prd, ref)
# compute the averaged powerspectra of prediction and reference
with torch.no_grad():
......@@ -251,35 +212,42 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
plt.xlabel("$l$")
plt.ylabel("powerspectrum")
plt.legend()
plt.savefig(os.path.join(path_root,f'powerspectrum_{step}.png'))
fig.tight_layout()
plt.savefig(os.path.join(path_root, f"powerspectrum_{step}.png"))
fig.clf()
plt.close()
return losses, fno_times, nwp_times
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes
"""
root_path = os.path.join(os.getcwd(), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname)
weights_dict = {k: v for k, v in model.named_parameters()}
grad_dict = {k: v.grad for k, v in model.named_parameters()}
store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
torch.save(store_dict, weights_and_grads_fname)
return losses, metrics, model_times, solver_times
# training function
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn="l2", enable_amp=False, log_grads=0):
def train_model(
model,
dataloader,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler=None,
nepochs=20,
nfuture=0,
num_examples=256,
num_valid=8,
amp_mode="none",
log_grads=0,
logging=True,
device=torch.device("cpu"),
):
train_start = time.time()
# set AMP type
amp_dtype = torch.float32
if amp_mode == "fp16":
amp_dtype = torch.float16
elif amp_mode == "bf16":
amp_dtype = torch.bfloat16
# count iterations
iters = 0
......@@ -295,31 +263,20 @@ def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=2
solver = dataloader.dataset.solver
# do the training
acc_loss = 0
accumulated_loss = 0
model.train()
for inp, tar in dataloader:
with torch.autocast(device_type="cuda", enabled=enable_amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(amp_mode != "none")):
prd = model(inp)
for _ in range(nfuture):
prd = model(prd)
if loss_fn == "l2":
loss = l2loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == "spectral l2":
loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == "h1":
loss = h1loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == "spectral":
loss = spectral_loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == "fluct":
loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
else:
raise NotImplementedError(f"Unknown loss function {loss_fn}")
acc_loss += loss.item() * inp.size(0)
loss = loss_fn(prd, tar)
accumulated_loss += loss.item() * inp.size(0)
optimizer.zero_grad(set_to_none=True)
gscaler.scale(loss).backward()
......@@ -332,39 +289,63 @@ def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=2
iters += 1
acc_loss = acc_loss / len(dataloader.dataset)
accumulated_loss = accumulated_loss / len(dataloader.dataset)
dataloader.dataset.set_initial_condition("random")
dataloader.dataset.set_num_examples(num_valid)
# perform validation
valid_loss = 0
# eval mode
model.eval()
# prepare loss buffer for validation loss
valid_loss = torch.zeros(2, dtype=torch.float32, device=device)
# prepare metrics buffer for accumulation of validation metrics
valid_metrics = {}
for metric in metrics_fns:
valid_metrics[metric] = torch.zeros(2, dtype=torch.float32, device=device)
# perform validation
with torch.no_grad():
for inp, tar in dataloader:
prd = model(inp)
for _ in range(nfuture):
prd = model(prd)
loss = l2loss_sphere(solver, prd, tar, relative=True)
loss = loss_fn(prd, tar).item()
valid_loss[0] += loss * inp.size(0)
valid_loss[1] += inp.size(0)
valid_loss += loss.item() * inp.size(0)
for metric in metrics_fns:
metric_buff = valid_metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[0] += metric_fn(prd, tar) * inp.size(0)
metric_buff[1] += inp.size(0)
valid_loss = valid_loss / len(dataloader.dataset)
valid_loss = (valid_loss[0] / valid_loss[1]).item()
for metric in valid_metrics:
valid_metrics[metric] = (valid_metrics[metric][0] / valid_metrics[metric][1]).item()
if scheduler is not None:
scheduler.step(valid_loss)
epoch_time = time.time() - epoch_start
print(f"--------------------------------------------------------------------------------")
print(f"Epoch {epoch} summary:")
print(f"time taken: {epoch_time}")
print(f"accumulated training loss: {acc_loss}")
print(f"relative validation loss: {valid_loss}")
if wandb.run is not None:
current_lr = optimizer.param_groups[0]["lr"]
wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})
if logging:
print(f"--------------------------------------------------------------------------------")
print(f"Epoch {epoch} summary:")
print(f"time taken: {epoch_time:.2f}")
print(f"accumulated training loss: {accumulated_loss}")
print(f"validation loss: {valid_loss}")
for metric in valid_metrics:
print(f"{metric}: {valid_metrics[metric]}")
if wandb.run is not None:
current_lr = optimizer.param_groups[0]["lr"]
log_dict = {"loss": accumulated_loss, "validation loss": valid_loss, "learning rate": current_lr}
for metric in valid_metrics:
log_dict[metric] = valid_metrics[metric]
wandb.log(log_dict)
train_time = time.time() - train_start
......@@ -373,18 +354,15 @@ def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=2
return valid_loss
def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
def main(root_path, pretrain_epochs=100, finetune_epochs=10, batch_size=1, learning_rate=1e-3, train=True, load_checkpoint=False, amp_mode="none", log_grads=0):
# enable logging by default
logging = True
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# login
wandb.login()
# set parameters
nfuture=0
# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
......@@ -395,79 +373,49 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt_solver = 150
nsteps = dt // dt_solver
grid = "legendre-gauss"
nlat, nlon = (257, 512)
nlat, nlon = (128, 256)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid= grid).to(device=device)
dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid=grid).to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, persistent_workers=False)
nlat = dataset.nlat
nlon = dataset.nlon
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# prepare dicts containing models and corresponding metrics
models = {}
metrics = {}
from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO
models[f"sfno_sc2_layers4_e32"] = partial(
SFNO,
img_size=(nlat, nlon),
grid=grid,
hard_thresholding_fraction=0.8,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
)
models[f"lsno_sc2_layers4_e32_morlet"] = partial(
LSNO,
img_size=(nlat, nlon),
grid=grid,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
kernel_shape=(2, 2),
encoder_kernel_shape=(2, 2),
filter_basis_type="morlet",
upsample_sht = True,
)
models[f"lsno_sc2_layers4_e32_zernike"] = partial(
LSNO,
img_size=(nlat, nlon),
grid=grid,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
kernel_shape=(4),
encoder_kernel_shape=(4),
filter_basis_type="zernike",
upsample_sht = True,
)
# get baseline model registry
baseline_models = get_baseline_models(img_size=(nlat, nlon), in_chans=3, out_chans=3, residual_prediction=True, grid=grid)
# specify which models to train here
models = [
"transformer_sc2_layers4_e128",
"s2transformer_sc2_layers4_e128",
"ntransformer_sc2_layers4_e128",
"s2ntransformer_sc2_layers4_e128",
"segformer_sc2_layers4_e128",
"s2segformer_sc2_layers4_e128",
"nsegformer_sc2_layers4_e128",
"s2nsegformer_sc2_layers4_e128",
# "sfno_sc2_layers4_e32",
# "lsno_sc2_layers4_e32",
]
models = {k: baseline_models[k] for k in models}
# loss function
loss_fn = SquaredL2LossS2(nlat=nlat, nlon=nlon, grid=grid).to(device)
# dictionary for logging the metrics
metrics = {}
metrics_fns = {
"L2 error": L2LossS2(nlat=nlat, nlon=nlon, grid=grid).to(device=device),
"L1 error": L1LossS2(nlat=nlat, nlon=nlon, grid=grid).to(device=device),
"W11 error": W11LossS2(nlat=nlat, nlon=nlon, grid=grid).to(device=device),
}
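    # additional metrics can be registered here as long as they expose the same
    # (prd, tar) -> scalar tensor call signature, e.g. (hypothetical):
    # metrics_fns["H1 error"] = H1LossS2(nlat=nlat, nlon=nlon, grid=grid).to(device=device)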
# iterate over models and train each model
root_path = os.getcwd()
for model_name, model_handle in models.items():
model = model_handle().to(device)
......@@ -480,7 +428,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params
exp_dir = os.path.join(root_path, 'checkpoints', model_name)
exp_dir = os.path.join(root_path, model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)
......@@ -489,58 +437,130 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
# run the training
if train:
run = wandb.init(project="local sno spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
if logging and wandb is not None:
run = wandb.init(project="spherical shallow water equations", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
else:
run = None
# optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
gscaler = torch.GradScaler("cuda", enabled=enable_amp)
gscaler = torch.GradScaler("cuda", enabled=(amp_mode == "fp16"))
start_time = time.time()
print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
if nfuture > 0:
print(f'Training {model_name}, {nfuture} step')
optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
dataloader.dataset.nsteps = 2 * dt//dt_solver
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
dataloader.dataset.nsteps = 1 * dt//dt_solver
if logging:
print(f"Training {model_name}, single step")
train_model(
model,
dataloader,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler,
nepochs=pretrain_epochs,
amp_mode=amp_mode,
log_grads=log_grads,
logging=logging,
device=device,
)
if finetune_epochs > 0:
nfuture = 1
if logging:
print(f"Finetuning {model_name}, {nfuture} step")
optimizer = torch.optim.Adam(model.parameters(), lr=0.1 * learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
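                # as in the pretraining stage, gradient scaling is only needed for fp16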
                gscaler = torch.GradScaler("cuda", enabled=(amp_mode == "fp16"))
dataloader.dataset.nsteps = 2 * dt // dt_solver
train_model(
model,
dataloader,
loss_fn,
metrics_fns,
optimizer,
gscaler,
scheduler,
nepochs=finetune_epochs,
nfuture=nfuture,
amp_mode=amp_mode,
log_grads=log_grads,
logging=logging,
device=device,
)
dataloader.dataset.nsteps = 1 * dt // dt_solver
training_time = time.time() - start_time
run.finish()
if logging and run is not None:
run.finish()
torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))
torch.save(model.state_dict(), os.path.join(exp_dir, "checkpoint.pt"))
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# run validation
print(f"Validating {model_name}")
with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir,'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
metrics[model_name]["loss_mean"] = np.mean(losses)
metrics[model_name]["loss_std"] = np.std(losses)
metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
metrics[model_name]["fno_time_std"] = np.std(fno_times)
metrics[model_name]["nwp_time_mean"] = np.mean(nwp_times)
metrics[model_name]["nwp_time_std"] = np.std(nwp_times)
losses, metric_results, model_times, solver_times = autoregressive_inference(
model, dataset, loss_fn, metrics_fns, os.path.join(exp_dir, "figures"), nsteps=nsteps, autoreg_steps=1, nics=50, device=device
)
# compute statistics
metrics[model_name]["loss mean"] = torch.mean(losses).item()
metrics[model_name]["loss std"] = torch.std(losses).item()
metrics[model_name]["model time mean"] = torch.mean(model_times).item()
metrics[model_name]["model time std"] = torch.std(model_times).item()
metrics[model_name]["solver time mean"] = torch.mean(solver_times).item()
metrics[model_name]["solver time std"] = torch.std(solver_times).item()
for metric in metric_results:
metrics[model_name][metric + " mean"] = torch.mean(metric_results[metric]).item()
metrics[model_name][metric + " std"] = torch.std(metric_results[metric]).item()
if train:
metrics[model_name]["training_time"] = training_time
# output metrics to data frame
df = pd.DataFrame(metrics)
if not os.path.isdir(os.path.join(exp_dir, 'output_data',)):
os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))
if not os.path.isdir(os.path.join(root_path, "output_data")):
os.makedirs(os.path.join(root_path, "output_data"), exist_ok=True)
df.to_pickle(os.path.join(root_path, "output_data", "metrics.pkl"))
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.set_start_method("forkserver", force=True)
if wandb is not None:
wandb.login()
parser = argparse.ArgumentParser()
parser.add_argument(
"--root_path", default=os.path.join(os.path.dirname(__file__), "checkpoints"), type=str, help="Override the path where checkpoints and run information are stored"
)
parser.add_argument("--pretrain_epochs", default=100, type=int, help="Number of pretraining epochs.")
parser.add_argument("--finetune_epochs", default=0, type=int, help="Number of fine-tuning epochs.")
parser.add_argument("--batch_size", default=4, type=int, help="Switch for overriding batch size in the configuration file.")
parser.add_argument("--learning_rate", default=1e-4, type=float, help="Switch to override learning rate.")
parser.add_argument("--resume", action="store_true", help="Reload checkpoints.")
parser.add_argument("--amp_mode", default="none", type=str, choices=["none", "bf16", "fp16"], help="Switch to enable AMP.")
args = parser.parse_args()
# main(train=False, load_checkpoint=True, enable_amp=False, log_grads=0)
main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)
main(
root_path=args.root_path,
pretrain_epochs=args.pretrain_epochs,
finetune_epochs=args.finetune_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
train=(args.pretrain_epochs > 0 or args.finetune_epochs > 0),
load_checkpoint=args.resume,
amp_mode=args.amp_mode,
log_grads=0,
)
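    # example invocation (a sketch; the script path and flag values are illustrative):
    #   python examples/train_swe.py --pretrain_epochs 100 --finetune_epochs 10 --batch_size 4 --amp_mode bf16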