Commit 6a845fd3 authored by Boris Bonev, committed by Boris Bonev

adding spherical attention

parent b3816ebc
*.DS_Store
__pycache__
*.so
checkpoints
\ No newline at end of file
...@@ -2,6 +2,9 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Max Rietmann - NVIDIA Corporation
Andrea Paris - NVIDIA Corporation
Alberto Carpentieri - NVIDIA Corporation
Mauro Bisson - NVIDIA Corporation
Massimiliano Fatica - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
......
...@@ -2,6 +2,19 @@
## Versioning
### v0.8.0
* Adding spherical attention and spherical neighborhood attention
* Custom CUDA kernels for spherical neighborhood attention
* New datasets for segmentation and depth estimation on the sphere based on the 2D3DS dataset
* Added new spherical architectures and corresponding baselines
* S2 Transformer
* S2 Segformer
* S2 U-Net
* Reworked spherical examples for better reproducibility
* Added spherical loss functions to examples
* Added plotting module
### v0.7.6
* Adding cache for precomputed tensors such as weight tensors for DISCO and SHT
......
...@@ -30,15 +30,14 @@
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:24.08-py3
FROM nvcr.io/nvidia/pytorch:24.12-py3
COPY . /workspace/torch_harmonics
# we need this for tests
RUN pip install parameterized
# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
COPY . /workspace/torch_harmonics
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
...@@ -29,68 +27,34 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn

# complex activation functions

class ComplexCardioid(nn.Module):
    """
    Complex Cardioid activation function
    """
    def __init__(self):
        super(ComplexCardioid, self).__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        out = 0.5 * (1. + torch.cos(z.angle())) * z
        return out

class ComplexReLU(nn.Module):
    """
    Complex-valued variants of the ReLU activation function
    """
    def __init__(self, negative_slope=0., mode="real", bias_shape=None, scale=1.):
        super(ComplexReLU, self).__init__()

        # store parameters
        self.mode = mode
        if self.mode in ["modulus", "halfplane"]:
            if bias_shape is not None:
                self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
            else:
                self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
        else:
            self.bias = 0

        self.negative_slope = negative_slope
        self.act = nn.LeakyReLU(negative_slope=negative_slope)

    def forward(self, z: torch.Tensor) -> torch.Tensor:

        if self.mode == "cartesian":
            zr = torch.view_as_real(z)
            za = self.act(zr)
            out = torch.view_as_complex(za)
        elif self.mode == "modulus":
            zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
            out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0)
        elif self.mode == "cardioid":
            out = 0.5 * (1. + torch.cos(z.angle())) * z
        # elif self.mode == "halfplane":
        #     # bias is an angle parameter in this case
        #     modified_angle = torch.angle(z) - self.bias
        #     condition = torch.logical_and( (0. <= modified_angle), (modified_angle < torch.pi/2.) )
        #     out = torch.where(condition, z, self.negative_slope * z)
        elif self.mode == "real":
            zr = torch.view_as_real(z)
            outr = zr.clone()
            outr[..., 0] = self.act(zr[..., 0])
            out = torch.view_as_complex(outr)
        else:
            raise NotImplementedError

        return out
\ No newline at end of file

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics

FROM nvcr.io/nvidia/pytorch:24.12-py3

# we need this for tests
RUN pip install parameterized

# we install this for the examples
RUN pip install wandb

# cartopy
RUN pip install cartopy

# h5py
RUN pip install h5py

# natten
RUN cd /opt && git clone https://github.com/SHI-Labs/NATTEN natten && \
    cd natten && \
    make WITH_CUDA=1 \
         CUDA_ARCH="7.0;7.2;7.5;8.0;8.6;8.7;9.0" \
         WORKERS=4

# install torch harmonics
COPY . /workspace/torch_harmonics

# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
...@@ -56,7 +56,7 @@ The SHT algorithm uses quadrature rules to compute the projection onto the assoc
torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators [1].
<div align="center">
<table border="0" cellspacing="0" cellpadding="0">
...@@ -169,9 +169,13 @@
Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
### Discrete-continuous convolutions on the sphere
torch-harmonics now provides local discrete-continuous (DISCO) convolutions on the sphere as outlined in [4]. These are used in local neural operators to generalize convolutions to structured and unstructured meshes on the sphere.
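A rough usage sketch is shown below. The `DiscreteContinuousConvS2` module is part of torch-harmonics, but the argument names used here are assumptions rather than the verified signature, so please consult the package documentation:

```python
import torch
import torch_harmonics as th

# hypothetical configuration: a DISCO convolution mapping a 128x256 equiangular grid
# to a 64x128 equiangular grid (argument names below are assumptions, see the docs)
conv = th.DiscreteContinuousConvS2(
    in_channels=3,
    out_channels=16,
    in_shape=(128, 256),
    out_shape=(64, 128),
    kernel_shape=(3, 3),
)
y = conv(torch.randn(1, 3, 128, 256))
```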
### Spherical (neighborhood) attention
torch-harmonics introduces spherical attention mechanisms which correctly generalize the attention mechanism to the sphere. The use of quadrature rules makes the resulting operations approximately equivariant, and exactly equivariant in the continuous limit. Moreover, neighborhood attention is correctly generalized onto the sphere by using the geodesic distance to determine the size of the neighborhood.
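The following is a minimal sketch of the idea behind the global variant, not the torch-harmonics implementation itself (which additionally ships custom CUDA kernels for the neighborhood variant): the quadrature weights enter the softmax normalization so that the discrete attention approximates the continuous integral over the sphere.

```python
import torch

def quadrature_weighted_attention(q, k, v, w):
    # q, k, v: (batch, heads, npoints, channels) tokens on a flattened spherical grid
    # w: (npoints,) positive quadrature weights approximating integration over S^2
    scale = q.shape[-1] ** -0.5
    logits = torch.einsum("bhnc,bhmc->bhnm", q, k) * scale
    # adding log(w) to the logits multiplies each exponential by its quadrature weight,
    # turning the softmax normalization into a quadrature rule over the sphere
    attn = torch.softmax(logits + torch.log(w)[None, None, None, :], dim=-1)
    return torch.einsum("bhnm,bhmc->bhnc", attn, v)
```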
## Getting started
...@@ -208,6 +212,16 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
8. [Training Spherical Fourier Neural Operators (SFNO)](./notebooks/train_sfno.ipynb)
9. [Resampling signals on the sphere](./notebooks/resample_sphere.ipynb)
## Examples and reproducibility
The `examples` folder contains training scripts for three distinct tasks:
* [solution of the shallow water equations on the rotating sphere](./examples/shallow_water_equations/train.py)
* [depth estimation on the sphere](./examples/depth/train.py)
* [semantic segmentation on the sphere](./examples/segmentation/train.py)
Results from the papers can generally be reproduced by running `python train.py`. For some older results, the number of epochs and the learning rate may need to be adjusted by passing the corresponding command-line arguments.
## Remarks on automatic mixed precision (AMP) support
Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
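A minimal sketch of this pattern is given below (assuming a `RealSHT` transform on a 360 x 720 equiangular grid; the exact snippet in the repository may differ):

```python
import torch
import torch_harmonics as th

sht = th.RealSHT(360, 720, grid="equiangular").cuda()
x = torch.randn(1, 21, 360, 720, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # ... other layers can run in reduced precision ...
    with torch.autocast(device_type="cuda", enabled=False):
        # run the SHT with autocast disabled to avoid the power-of-two restriction
        coeffs = sht(x.float())
```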
...@@ -237,7 +251,7 @@ Depending on the problem, it might be beneficial to upcast data to `float64` ins
## Contributors
[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Max Rietmann](https://github.com/rietmann-nv), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ), [Andrea Paris](https://github.com/apaaris), [Alberto Carpentieri](https://github.com/albertocarpentieri), [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)
## Cite us
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
...@@ -29,48 +29,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
sys.path.append("..")
sys.path.append(".")

import torch
import torch_harmonics as harmonics

try:
    from tqdm import tqdm
except:
    tqdm = lambda x: x

# everything is awesome on GPUs
device = torch.device("cuda")

# create a batch with one sample and 21 channels
b, c, n_theta, n_lambda = 1, 21, 360, 720

# your layers to play with
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_equi = harmonics.RealSHT(n_theta, n_lambda, grid="equiangular").to(device)
inverse_transform_equi = harmonics.InverseRealSHT(n_theta, n_lambda, grid="equiangular").to(device)

signal_leggauss = inverse_transform(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))
signal_equi = inverse_transform(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))

# let's check the layers
for num_iters in [1, 8, 64, 512]:
    base = signal_leggauss
    for iteration in tqdm(range(num_iters)):
        base = inverse_transform(forward_transform(base))
    print("relative l2 error accumulation on the legendre-gauss grid: ",
          torch.mean(torch.norm(base-signal_leggauss, p='fro', dim=(-1,-2)) / torch.norm(signal_leggauss, p='fro', dim=(-1,-2))).item(),
          "after", num_iters, "iterations")

# let's check the equiangular layers
for num_iters in [1, 8, 64, 512]:
    base = signal_equi
    for iteration in tqdm(range(num_iters)):
        base = inverse_transform_equi(forward_transform_equi(base))
    print("relative l2 error accumulation with interpolation onto equiangular grid: ",
          torch.mean(torch.norm(base-signal_equi, p='fro', dim=(-1,-2)) / torch.norm(signal_equi, p='fro', dim=(-1,-2))).item(),
          "after", num_iters, "iterations")

from .transformer import Transformer
from .segformer import Segformer
from .unet import UNet
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
from natten import NeighborhoodAttention2D as NeighborhoodAttention
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath
from functools import partial
class OverlapPatchMerging(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(481, 960),
in_channels=3,
out_channels=64,
kernel_shape=(3, 3),
bias=False,
):
super().__init__()
# conv
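# the stride is the integer ratio between input and output grid, and the padding
# inverts the conv output-size relation out = (in + 2*pad - kernel)//stride + 1,
# so that the strided convolution lands on out_shape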
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2)
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_shape,
bias=bias,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
)
# layer norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.conv(x)
# permute
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
out = x.permute(0, 3, 1, 2)
return out
class MixFFN(nn.Module):
def __init__(
self,
shape,
inout_channels,
hidden_channels,
mlp_bias=True,
kernel_shape=(3, 3),
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
drop_path=0.0,
):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm = nn.LayerNorm((inout_channels), eps=1e-05, elementwise_affine=True, bias=True)
if use_mlp:
# although the paper says MLP, it uses a single linear layer
self.mlp_in = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_in = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.conv = nn.Conv2d(inout_channels, inout_channels, kernel_size=kernel_shape, groups=inout_channels, bias=conv_bias, padding="same")
if use_mlp:
self.mlp_out = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp_out = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)
self.act = activation()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
# norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here
# because in the paper they only use depthwise conv,
# but without this activation it would just be a fused MM
# with the disco conv
x = self.mlp_in(x)
# conv path
x = self.act(self.conv(x))
# second linear
x = self.mlp_out(x)
return residual + self.drop_path(x)
class GlobalAttention(nn.Module):
"""
Global self-attention block over 2D inputs using MultiheadAttention.
Expects channels-last input of shape (B, H, W, C).
Returns output of the same shape; the residual connection is applied by the caller.
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C (channels last)
B, H, W, C = x.shape
# flatten spatial dims
x_flat = x.reshape(B, H * W, C) # B, N, C (reshape, since x may be non-contiguous after permute)
# self-attention
out, _ = self.attn(x_flat, x_flat, x_flat)
# reshape back
out = out.view(B, H, W, C)
return out
class AttentionWrapper(nn.Module):
def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.attention_mode = attention_mode
if attention_mode == "neighborhood":
self.att = NeighborhoodAttention(
channels, kernel_size=kernel_shape, dilation=1, num_heads=heads, qk_scale=None, attn_drop=attention_drop_rate, proj_drop=0.0, qkv_bias=bias
)
elif attention_mode == "global":
self.att = GlobalAttention(channels, num_heads=heads, dropout=attention_drop_rate, bias=bias)
else:
raise ValueError(f"Unknown attention mode function {attention_mode}")
self.norm = None
if pre_norm:
self.norm = nn.LayerNorm((channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
x = self.att(x)
x = x.permute(0, 3, 1, 2)
return residual + self.drop_path(x)
class TransformerBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
mlp_hidden_channels,
nrep=1,
heads=1,
kernel_shape=(3, 3),
activation=nn.GELU,
att_drop_rate=0.0,
drop_path_rates=0.0,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=True
):
super().__init__()
# ensure odd
if attn_kernel_shape[0] % 2 == 0:
raise ValueError(f"Attn Kernel shape {kernel_shape} is even, use odd kernel shape")
if attn_kernel_shape[1] % 2 == 0:
raise ValueError(f"Kernel shape {kernel_shape} is even, use odd kernel shape")
attn_kernel_shape = list(attn_kernel_shape)
orig_attn_kernel_shape = attn_kernel_shape.copy()
# ensure that attn kernel shape is smaller than in_shape in both dimensions
# if necessary fix kernel_shape to be 1 less (and odd) than in_shape
if attn_kernel_shape[0] >= out_shape[0]:
attn_kernel_shape[0] = out_shape[0] - 1
# ensure odd
if attn_kernel_shape[0] % 2 == 0:
attn_kernel_shape[0] -= 1
# make square if original was square
if orig_attn_kernel_shape[0] == orig_attn_kernel_shape[1]:
attn_kernel_shape[1] = attn_kernel_shape[0]
if attn_kernel_shape[1] >= out_shape[1]:
attn_kernel_shape[1] = out_shape[1] - 1
# ensure odd
if attn_kernel_shape[1] % 2 == 0:
attn_kernel_shape[1] -= 1
attn_kernel_shape = tuple(attn_kernel_shape)
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(drop_path_rates, float):
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rates, nrep)]
assert len(drop_path_rates) == nrep
self.fwd = [
OverlapPatchMerging(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
bias=False,
)
]
for i in range(nrep):
self.fwd.append(
AttentionWrapper(
channels=out_channels,
shape=out_shape,
heads=heads,
pre_norm=True,
attention_drop_rate=att_drop_rate,
drop_path=drop_path_rates[i],
attention_mode=attention_mode,
kernel_shape=attn_kernel_shape,
bias=bias
)
)
self.fwd.append(
MixFFN(
out_shape,
inout_channels=out_channels,
hidden_channels=mlp_hidden_channels,
mlp_bias=True,
kernel_shape=kernel_shape,
conv_bias=False,
activation=activation,
use_mlp=False,
drop_path=drop_path_rates[i],
)
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fwd(x)
# apply norm
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
class Upsampling(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
hidden_channels,
mlp_bias=True,
kernel_shape=(3, 3),
conv_bias=False,
activation=nn.GELU,
use_mlp=False,
):
super().__init__()
self.out_shape = out_shape
if use_mlp:
self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
else:
self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = nn.functional.interpolate(self.mlp(x), size=self.out_shape, mode="bilinear")
return x
class Segformer(nn.Module):
"""
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_size : tuple, optional
Shape of the input grid (nlat, nlon), by default (128, 256)
kernel_shape: tuple, int
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as heads
heads : List[int], optional
Number of heads for each block in the network, has to be the same length as embed_dims
depths: List[int], optional
Number of repetitions of attentions blocks and ffn mixers per layer. Has to be the same length as embed_dims and heads
activation_function : str, optional
Activation function to use, by default "gelu"
embedder_kernel_shape : int, optional
size of the encoder kernel
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
-----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
... out_chans=3,
... embed_dims=[64, 128, 256, 512],
... heads=[1, 2, 4, 8],
... depths=[3, 4, 6, 3],
... scale_factor=2,
... activation_function="gelu",
... kernel_shape=(3, 3),
... mlp_ratio=2.0,
... att_drop_rate=0.0,
... drop_path_rate=0.1,
... attention_mode="global",
... )
>>> model(torch.randn(1, 3, 128, 256)).shape
torch.Size([1, 3, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
in_chans=3,
out_chans=3,
embed_dims=[64, 128, 256, 512],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(3, 3),
mlp_ratio=2.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=True
):
super().__init__()
self.img_size = img_size
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dims = embed_dims
self.heads = heads
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert len(self.heads) == self.num_blocks
assert len(self.depths) == self.num_blocks
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
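# dpr[i] increases linearly from 0 to drop_path_rate, so deeper blocks are dropped
# more aggressively (stochastic depth)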
self.blocks = nn.ModuleList([])
out_shape = img_size
in_channels = in_chans
cur = 0
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.blocks.append(
TransformerBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
mlp_hidden_channels=int(mlp_ratio * out_channels),
nrep=self.depths[i],
heads=self.heads[i],
kernel_shape=kernel_shape,
activation=self.activation_function,
att_drop_rate=att_drop_rate,
drop_path_rates=dpr[cur : cur + self.depths[i]],
attention_mode=attention_mode,
attn_kernel_shape=attn_kernel_shape,
bias=bias
)
)
cur += self.depths[i]
out_shape = out_shape_new
in_channels = out_channels
self.upsamplers = nn.ModuleList([])
out_shape = img_size
for i in range(self.num_blocks):
in_shape = self.blocks[i].out_shape
self.upsamplers.append(
Upsampling(
in_shape=in_shape,
out_shape=out_shape,
in_channels=self.embed_dims[i],
out_channels=self.embed_dims[i],
hidden_channels=int(mlp_ratio * self.embed_dims[i]),
mlp_bias=True,
kernel_shape=kernel_shape,
conv_bias=False,
activation=nn.GELU,
)
)
segmentation_head_dim = sum(self.embed_dims)
self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for block in self.blocks:
feat = block(feat)
features.append(feat)
# perform upsample
upfeats = []
for feat, upsampler in zip(features, self.upsamplers):
upfeat = upsampler(feat)
upfeats.append(upfeat)
# perform concatenation
upfeats = torch.cat(upfeats, dim=1)
# final upsampling and prediction
out = self.segmentation_head(upfeats)
return out
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
from natten import NeighborhoodAttention2D as NeighborhoodAttention
from functools import partial
class Encoder(nn.Module):
def __init__(
self,
in_shape=(721, 1440),
out_shape=(480, 960),
in_chans=2,
out_chans=2,
kernel_shape=(3, 3),
groups=1,
bias=False,
):
super().__init__()
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2)
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x):
x = self.conv(x)
return x
class Decoder(nn.Module):
def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
super().__init__()
self.out_shape = out_shape
self.upsampling_method = upsampling_method
if upsampling_method == "conv":
self.upsample = nn.Sequential(
nn.Upsample(
size=out_shape,
mode="bilinear",
),
nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, padding="same", groups=groups),
)
elif upsampling_method == "pixel_shuffle":
# check if it is possible to use PixelShuffle
if out_shape[0] // in_shape[0] != out_shape[1] // in_shape[1]:
raise Exception(f"out_shape {out_shape} and in_shape {in_shape} are incompatible for shuffle decoding")
upsampling_factor = out_shape[0] // in_shape[0]
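# the 1x1 conv expands the channel count by upsampling_factor**2 and PixelShuffle
# rearranges those channels into an upsampling_factor-times finer spatial grid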
self.upsample = nn.Sequential(
nn.Conv2d(in_chans, out_chans * (upsampling_factor**2), kernel_size=1, bias=bias, padding=0, groups=groups), nn.PixelShuffle(upsampling_factor)
)
else:
raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x):
x = self.upsample(x)
return x
class GlobalAttention(nn.Module):
"""
Global self-attention block over 2D inputs using MultiheadAttention.
Expects channels-last input of shape (B, H, W, C).
Returns output of the same shape; the residual connection is applied by the caller.
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C (channels last)
B, H, W, C = x.shape
# flatten spatial dims
x_flat = x.reshape(B, H * W, C) # B, N, C
# self-attention
out, _ = self.attn(x_flat, x_flat, x_flat)
# reshape back
out = out.view(B, H, W, C)
return out
class AttentionBlock(nn.Module):
"""
Neighborhood attention block based on Natten.
"""
def __init__(
self,
in_shape=(480, 960),
out_shape=(480, 960),
chans=2,
num_heads=1,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="none",
use_mlp=True,
bias=True,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
):
super().__init__()
# normalisation layer
if norm_layer == "layer_norm":
self.norm0 = LayerNorm(in_channels=chans, eps=1e-6)
self.norm1 = LayerNorm(in_channels=chans, eps=1e-6)
elif norm_layer == "instance_norm":
self.norm0 = nn.InstanceNorm2d(num_features=chans, eps=1e-6, affine=True, track_running_stats=False)
self.norm1 = nn.InstanceNorm2d(num_features=chans, eps=1e-6, affine=True, track_running_stats=False)
elif norm_layer == "none":
self.norm0 = nn.Identity()
self.norm1 = nn.Identity()
else:
raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
# determine shape for neighborhood attention
if attention_mode == "neighborhood":
self.self_attn = NeighborhoodAttention(
chans,
kernel_size=attn_kernel_shape,
dilation=1,
num_heads=num_heads,
qkv_bias=bias,
qk_scale=None,
attn_drop=drop_rate,
proj_drop=drop_rate,
)
else:
self.self_attn = GlobalAttention(chans, num_heads=num_heads, dropout=drop_rate, bias=bias)
self.skip0 = nn.Identity()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
if use_mlp == True:
mlp_hidden_dim = int(chans * mlp_ratio)
self.mlp = MLP(
in_features=chans,
out_features=chans,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop_rate=drop_rate,
checkpointing=False,
gain=0.5,
)
self.skip1 = nn.Identity()
def forward(self, x):
residual = x
x = self.norm0(x)
x = x.permute(0, 2, 3, 1)
x = self.self_attn(x).permute(0, 3, 1, 2)
if hasattr(self, "skip0"):
x = x + self.skip0(residual)
residual = x
x = self.norm1(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, "skip1"):
x = x + self.skip1(residual)
return x
class Transformer(nn.Module):
"""
Parameters
----------
img_size : tuple of int
(latitude, longitude) size of the input tensor.
scale_factor : int
Ratio for down- and up-sampling between input and internal resolution.
in_chans : int
Number of channels in the input tensor.
out_chans : int
Number of channels in the output tensor.
embed_dim : int
Embedding dimension inside attention blocks.
num_layers : int
Number of attention blocks.
activation_function : str
"relu", "gelu", or "identity" specifying the activation.
encoder_kernel_shape : tuple of int
Kernel size for the encoder convolution.
num_heads : int
Number of heads in NeighborhoodAttention.
use_mlp : bool
If True, an MLP follows attention in each block.
mlp_ratio : float
Ratio of MLP hidden dim to input dim.
drop_rate : float
Dropout rate before positional embedding.
drop_path_rate : float
Stochastic depth rate across transformer blocks.
normalization_layer : str
"layer_norm", "instance_norm", or "none".
residual_prediction : bool
If True, add the input as a global skip connection.
pos_embed : str
"sequence", "spectral", "learnable lat", "learnable latlon", or "none".
bias : bool
Whether convolution and attention projections include bias.
attention_mode: str
"neighborhood" or "global"
upsampling_method: str
"conv" or "pixel_shuffle"
attn_kernel_shape: tuple
Kernel size of the neighborhood attention window, by default (7, 7).
Example
-------
>>> model = Transformer(
... img_size=(128, 256),
... scale_factor=2,
... in_chans=3,
... out_chans=3,
... embed_dim=256,
... num_layers=4,
... activation_function="gelu",
... encoder_kernel_shape=(3, 3),
... num_heads=1,
... use_mlp=True,
... mlp_ratio=2.0,
... drop_rate=0.0,
... drop_path_rate=0.0,
... normalization_layer="none",
... residual_prediction=False,
... pos_embed="spectral",
... bias=True,
... attention_mode="neighborhood",
... attn_kernel_shape=(7,7),
... upsampling_method="conv"
... )
>>> x = torch.randn(1, 3, 128, 256)
>>> print(model(x).shape)
torch.Size([1, 3, 128, 256])
"""
def __init__(
self,
img_size=(128, 256),
grid_internal="legendre-gauss",
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="gelu",
encoder_kernel_shape=(3, 3),
num_heads=1,
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
residual_prediction=False,
pos_embed="spectral",
bias=True,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
upsampling_method="conv",
):
super().__init__()
self.img_size = img_size
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = embed_dim
self.num_layers = num_layers
self.encoder_kernel_shape = encoder_kernel_shape
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.residual_prediction = residual_prediction
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
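# e.g. img_size=(721, 1440) with scale_factor=3 gives an internal grid of (241, 480)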
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
if pos_embed == "sequence":
self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "spectral":
self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
elif pos_embed == "learnable lat":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
elif pos_embed == "learnable latlon":
self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
elif pos_embed == "none":
self.pos_embed = nn.Identity()
else:
raise ValueError(f"Unknown position embedding type {pos_embed}")
# encoder
self.encoder = Encoder(
in_shape=self.img_size,
out_shape=(self.h, self.w),
in_chans=self.in_chans,
out_chans=self.embed_dim,
kernel_shape=self.encoder_kernel_shape,
groups=1,
bias=False,
)
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
block = AttentionBlock(
in_shape=(self.h, self.w),
out_shape=(self.h, self.w),
chans=self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=self.normalization_layer,
use_mlp=use_mlp,
bias=bias,
attention_mode=attention_mode,
attn_kernel_shape=attn_kernel_shape,
)
self.blocks.append(block)
# decoder
self.decoder = Decoder(
in_shape=(self.h, self.w),
out_shape=self.img_size,
in_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
groups=1,
bias=False,
upsampling_method=upsampling_method,
)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.residual_prediction:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = self.pos_embed(x)
x = self.forward_features(x)
x = self.decoder(x)
if self.residual_prediction:
x = x + residual
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
from torch_harmonics.examples.models._layers import MLP, DropPath
from functools import partial
class DownsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
nrep=1,
kernel_shape=(3, 3),
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.,
drop_path_rate=0.,
drop_dense_rate=0.,
downsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.downsampling_mode = downsampling_mode
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.fwd =[]
for i in range(nrep):
# conv
self.fwd.append(
nn.Conv2d(
in_channels=(in_channels if i==0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same"
)
)
if drop_conv_rate > 0.:
self.fwd.append(
nn.Dropout2d(
p=drop_conv_rate
)
)
# batchnorm
self.fwd.append(
nn.BatchNorm2d(out_channels,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
)
# activation
self.fwd.append(
activation(),
)
if downsampling_mode == "conv":
stride_h = in_shape[0] // out_shape[0]
stride_w = in_shape[1] // out_shape[1]
pad_h = math.ceil(((out_shape[0] - 1) * stride_h
- in_shape[0]
+ kernel_shape[0]) / 2)
pad_w = math.ceil(((out_shape[1] - 1) * stride_w
- in_shape[1]
+ kernel_shape[1]) / 2)
self.downsample = nn.Conv2d(
in_channels=(in_channels if i==0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w)
)
else:
self.downsample = nn.Identity()
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
bias=True)
if drop_dense_rate >0.:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = self.fwd(x)
# add residual connection
x = residual + self.drop_path(x)
# downsample
x = self.downsample(x)
if self.downsampling_mode == "bilinear":
x = F.interpolate(x, size=self.out_shape, mode="bilinear")
return x
class UpsamplingBlock(nn.Module):
def __init__(
self,
in_shape,
out_shape,
in_channels,
out_channels,
nrep=1,
kernel_shape=(3, 3),
activation=nn.ReLU,
transform_skip=False,
drop_conv_rate=0.,
drop_path_rate=0.,
drop_dense_rate=0.,
upsampling_mode="bilinear",
):
super().__init__()
self.in_shape = in_shape
self.out_shape = out_shape
self.in_channels = in_channels
self.out_channels = out_channels
self.upsampling_mode = upsampling_mode
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
if in_shape != out_shape:
if upsampling_mode == "conv":
stride_h = out_shape[0] // in_shape[0]
stride_w = out_shape[1] // in_shape[1]
pad_h = math.ceil(((in_shape[0] - 1) * stride_h
- in_shape[0]
+ kernel_shape[0]) / 2)
pad_w = math.ceil(((in_shape[1] - 1) * stride_w
- in_shape[1]
+ kernel_shape[1]) / 2)
self.upsample = nn.Sequential(
nn.ConvTranspose2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_shape,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w)
),
nn.BatchNorm2d(out_channels,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True),
activation(),
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same")
)
self.fwd =[]
for i in range(nrep):
# conv
self.fwd.append(
nn.Conv2d(
in_channels=(in_channels if i == 0 else out_channels),
out_channels=out_channels,
kernel_size=kernel_shape,
bias=False,
padding="same")
)
if drop_conv_rate > 0.:
self.fwd.append(
nn.Dropout2d(
p=drop_conv_rate
)
)
# batchnorm
self.fwd.append(
nn.BatchNorm2d(out_channels,  # the preceding conv always outputs out_channels
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
)
# activation
self.fwd.append(
activation(),
)
# make sequential
self.fwd = nn.Sequential(*self.fwd)
# final norm
if transform_skip or (in_channels != out_channels):
self.transform_skip = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
bias=True)
if drop_dense_rate >0.:
self.transform_skip = nn.Sequential(
self.transform_skip,
nn.Dropout2d(p=drop_dense_rate),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
residual = self.transform_skip(residual)
# main path
x = residual + self.drop_path(self.fwd(x))
# upsampling
if self.upsampling_mode=="bilinear":
x = F.interpolate(x, size=self.out_shape, mode="bilinear")
else:
x = self.upsample(x)
return x
class UNet(nn.Module):
"""
U-Net model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
Number of input channels, by default 3
num_classes : int, optional
Number of classes, by default 3
embed_dims : List[int], optional
Dimension of the embeddings for each block, has to be the same length as depths
depths: List[int], optional
Number of repetitions of conv blocks and ffn mixers per layer. Has to be the same length as embed_dims
activation_function : str, optional
Activation function to use, by default "relu"
embedder_kernel_shape : int, optional
size of the encoder kernel
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
-----------
>>> model = UNet(
... img_shape=(128, 256),
... scale_factor=2,
... in_chans=2,
... num_classes=2,
... embed_dims=[64, 128, 256, 512],)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_shape=(128, 256),
in_chans=3,
num_classes=3,
embed_dims=[64, 128, 256, 512],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="relu",
kernel_shape=(3, 3),
transform_skip=False,
drop_conv_rate=0.1,
drop_path_rate=0.1,
drop_dense_rate=0.5,
downsampling_mode="bilinear",
upsampling_mode="bilinear",
):
super().__init__()
self.img_shape = img_shape
self.in_chans = in_chans
self.num_classes = num_classes
self.embed_dims = embed_dims
self.num_blocks = len(self.embed_dims)
self.depths = depths
self.kernel_shape = kernel_shape
assert(len(self.depths) == self.num_blocks)
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# set up drop path rates
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]
self.dblocks = nn.ModuleList([])
out_shape = img_shape
in_channels = in_chans
for i in range(self.num_blocks):
out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
out_channels = self.embed_dims[i]
self.dblocks.append(
DownsamplingBlock(
in_shape=out_shape,
out_shape=out_shape_new,
in_channels=in_channels,
out_channels=out_channels,
nrep=self.depths[i],
kernel_shape=kernel_shape,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=dpr[i],
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
downsampling_mode=downsampling_mode,
)
)
out_shape = out_shape_new
in_channels = out_channels
self.ublocks = nn.ModuleList([])
for i in range(self.num_blocks-1, -1, -1):
in_shape = self.dblocks[i].out_shape
out_shape = self.dblocks[i].in_shape
in_channels = self.dblocks[i].out_channels
if i != self.num_blocks-1:
in_channels = 2 * in_channels
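# all but the innermost decoder block receive the previous decoder output concatenated
# with the matching encoder feature map (see forward), hence the doubled channel count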
out_channels = self.dblocks[i].in_channels
if i==0:
out_channels = self.embed_dims[0]
self.ublocks.append(
UpsamplingBlock(
in_shape=in_shape,
out_shape=out_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_shape=kernel_shape,
activation=self.activation_function,
drop_conv_rate=drop_conv_rate,
drop_path_rate=0.,
drop_dense_rate=drop_dense_rate,
transform_skip=transform_skip,
upsampling_mode=upsampling_mode,
)
)
self.head = nn.Conv2d(self.embed_dims[0], self.num_classes, kernel_size=1, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# encoder:
features = []
feat = x
for dblock in self.dblocks:
feat = dblock(feat)
features.append(feat)
# reverse list
features = features[::-1]
# perform upsample
ufeat = self.ublocks[0](features[0])
for feat, ublock in zip(features[1:], self.ublocks[1:]):
ufeat = ublock(torch.cat([feat, ufeat], dim=1))
# last layer
out = self.head(ufeat)
return out
if __name__ == "__main__":
model = UNet(
img_shape=(128, 256),
scale_factor=2,
in_chans=2,
embed_dims=[64, 128, 256],
depths=[2, 2, 2])
print(model)
print(model(torch.randn(1, 2, 128, 256)).shape)
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os, sys
from functools import partial
# import baseline models
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
# prepare dicts containing models and corresponding metrics
model_registry = dict(
sfno_sc2_layers4_e32 = partial(
SphericalFourierNeuralOperator,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
residual_prediction=residual_prediction,
use_mlp=True,
normalization_layer="instance_norm",
),
lsno_sc2_layers4_e32 = partial(
LocalSphericalNeuralOperator,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=32,
activation_function="gelu",
residual_prediction=residual_prediction,
use_mlp=True,
normalization_layer="instance_norm",
kernel_shape=(5, 4),
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
upsample_sht=False,
),
s2unet_sc2_layers4_e128 = partial(
SphericalUNet,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
depths=[2, 2, 2, 2],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=0.1,
drop_conv_rate=0.2,
drop_dense_rate=0.5,
transform_skip=False,
upsampling_mode="conv",
downsampling_mode="conv",
),
s2transformer_sc2_layers4_e128 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="global",
bias=False
),
s2transformer_sc2_layers4_e256 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="global",
bias=False
),
s2ntransformer_sc2_layers4_e128 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="neighborhood",
bias=False
),
s2ntransformer_sc2_layers4_e256 = partial(
SphericalTransformer,
img_size=img_size,
grid=grid,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
drop_path_rate=drop_path_rate,
upsample_sht=False,
attention_mode="neighborhood",
bias=False
),
transformer_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="global",
upsampling_method="conv",
bias=False
),
transformer_sc2_layers4_e256 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="global",
upsampling_method="conv",
bias=False
),
ntransformer_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
ntransformer_sc2_layers4_e256 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=256,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="spectral",
use_mlp=True,
normalization_layer="instance_norm",
encoder_kernel_shape=(3, 3),
drop_path_rate=drop_path_rate,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
s2segformer_sc2_layers4_e128 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
s2segformer_sc2_layers4_e256 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
s2nsegformer_sc2_layers4_e128 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
bias=False
),
s2nsegformer_sc2_layers4_e256 = partial(
SphericalSegformer,
img_size=img_size,
grid=grid,
grid_internal="equiangular",
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(5, 4),
filter_basis_type="piecewise linear",
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
bias=False
),
segformer_sc2_layers4_e128 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
segformer_sc2_layers4_e256 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="global",
bias=False
),
nsegformer_sc2_layers4_e128 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[16, 32, 64, 128],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
nsegformer_sc2_layers4_e256 = partial(
Segformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
embed_dims=[32, 64, 128, 256],
heads=[1, 2, 4, 8],
depths=[3, 4, 6, 3],
scale_factor=2,
activation_function="gelu",
kernel_shape=(4, 4),
mlp_ratio=4.0,
att_drop_rate=0.0,
drop_path_rate=0.1,
attention_mode="neighborhood",
attn_kernel_shape=(7, 7),
bias=False
),
vit_sc2_layers4_e128 = partial(
Transformer,
img_size=img_size,
in_chans=in_chans,
out_chans=out_chans,
num_layers=4,
scale_factor=2,
embed_dim=128,
activation_function="gelu",
residual_prediction=residual_prediction,
pos_embed="learnable latlon",
use_mlp=True,
normalization_layer="layer_norm",
encoder_kernel_shape=(2, 2),
attention_mode="global",
upsampling_method="pixel_shuffle",
bias=False
),
)
return model_registry
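

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the example training
    # scripts): the registry maps descriptive model names to partially-applied constructors,
    # so a baseline is instantiated by looking up its name and calling the partial with no
    # further arguments. The input tensor layout (batch, channels, nlat, nlon) is assumed here.
    import torch

    registry = get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3)
    model = registry["s2transformer_sc2_layers4_e128"]()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 128, 256))
    print(out.shape)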