Unverified commit c7afb546 authored by Thorsten Kurth, committed by GitHub

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR
parents b5c410c0 644465ba
Pipeline #2854 canceled
......@@ -14,6 +14,7 @@
* Reorganized examples folder, including new examples based on the 2d3ds dataset
* Added spherical loss functions to examples
* Added plotting module
* Updated docstrings
### v0.7.6
......
......@@ -41,6 +41,24 @@ from functools import partial
class OverlapPatchMerging(nn.Module):
"""
OverlapPatchMerging layer for merging patches.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
kernel_shape : tuple
Kernel shape for convolution
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
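For orientation only (not part of this diff), a minimal doctest-style sketch of how the documented parameters might be used; the output resolution, channel counts, and the class being importable in scope are assumptions, not values taken from the PR:
>>> import torch
>>> merge = OverlapPatchMerging(in_shape=(721, 1440), out_shape=(361, 720),
...                             in_channels=3, out_channels=64, kernel_shape=(7, 7))
>>> x = torch.randn(1, 3, 721, 1440)   # (batch, in_channels, height, width)
>>> y = merge(x)                       # expected shape (1, 64, 361, 720)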
......@@ -88,6 +106,30 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
"""
MixFFN module combining MLP and depthwise convolution.
Parameters
----------
shape : tuple
Input shape (height, width)
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP layers, by default True
kernel_shape : tuple, optional
Kernel shape for depthwise convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : callable, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
"""
def __init__(
self,
shape,
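A hedged construction sketch for the documented arguments (illustrative values, not from the diff); the memory layout expected by forward is an assumption here, see the permutes in the forward pass below:
>>> import torch
>>> ffn = MixFFN(shape=(128, 256), inout_channels=64, hidden_channels=256,
...              kernel_shape=(3, 3), use_mlp=True, drop_path=0.1)
>>> x = torch.randn(1, 64, 128, 256)   # input layout assumed
>>> y = ffn(x)                         # output channels equal inout_channels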
......@@ -142,7 +184,7 @@ class MixFFN(nn.Module):
x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here
# because in the paper they only use depthwise conv,
# because in the paper the authors only use depthwise conv,
# but without this activation it would just be a fused MM
# with the disco conv
x = self.mlp_in(x)
......@@ -162,6 +204,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, C, H, W)
Output shape: (B, C, H, W) with residual skip.
Parameters
----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
......@@ -169,6 +222,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C (channels last)
B, H, W, C = x.shape
# flatten spatial dims
......@@ -181,6 +235,30 @@ class GlobalAttention(nn.Module):
class AttentionWrapper(nn.Module):
"""
Wrapper for different attention mechanisms.
Parameters
----------
channels : int
Number of channels
shape : tuple
Input shape (height, width)
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
super().__init__()
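Illustrative only: a sketch of constructing the wrapper with global attention. The channel and head counts are assumptions; the channels-first input follows the permute in the forward pass below:
>>> import torch
>>> attn = AttentionWrapper(channels=64, shape=(128, 256), heads=8,
...                         pre_norm=True, attention_mode="global")
>>> x = torch.randn(1, 64, 128, 256)   # (batch, channels, height, width)
>>> y = attn(x)                        # residual wrapper, shape assumed unchanged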
......@@ -203,11 +281,13 @@ class AttentionWrapper(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.permute(0, 2, 3, 1)
if self.norm is not None:
......@@ -219,6 +299,41 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block with attention and MLP.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for the depthwise convolution in the MixFFN block, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
in_shape,
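A construction sketch combining the documented parameters (the concrete shapes and channel counts are illustrative assumptions, not taken from the PR):
>>> import torch
>>> block = TransformerBlock(in_shape=(128, 256), out_shape=(64, 128),
...                          in_channels=32, out_channels=64,
...                          mlp_hidden_channels=128, nrep=2, heads=4,
...                          attention_mode="neighborhood", attn_kernel_shape=(7, 7))
>>> y = block(torch.randn(1, 32, 128, 256))   # expected shape (1, 64, 64, 128)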
......@@ -341,6 +456,33 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling block for the Segformer model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def __init__(
self,
in_shape,
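Same pattern for the decoder-side upsampling block; the shapes and channel counts below are assumptions for illustration:
>>> import torch
>>> up = Upsampling(in_shape=(64, 128), out_shape=(128, 256),
...                 in_channels=64, out_channels=32, hidden_channels=128)
>>> y = up(torch.randn(1, 64, 64, 128))       # expected shape (1, 32, 128, 256)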
......@@ -382,7 +524,7 @@ class Segformer(nn.Module):
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
-----------
----------
img_shape : tuple, optional
Shape of the input image (height, width), by default (128, 256)
kernel_shape : tuple or int
......@@ -414,7 +556,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
-----------
----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
......
......@@ -57,11 +57,34 @@ class Encoder(nn.Module):
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x):
x = self.conv(x)
return x
class Decoder(nn.Module):
"""
Decoder module for upsampling and feature processing.
Parameters
----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (721, 1440)
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
groups : int, optional
Number of groups for convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsampling_method : str, optional
Upsampling method ("conv", "pixel_shuffle"), by default "conv"
"""
def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
super().__init__()
self.out_shape = out_shape
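A sketch using the constructor defaults shown above; the output resolution following out_shape is an assumption based on the docstring:
>>> import torch
>>> dec = Decoder(in_shape=(480, 960), out_shape=(721, 1440),
...               in_chans=2, out_chans=2, upsampling_method="conv")
>>> y = dec(torch.randn(1, 2, 480, 960))      # expected shape (1, 2, 721, 1440)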
......@@ -87,6 +110,7 @@ class Decoder(nn.Module):
raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x):
x = self.upsample(x)
return x
......@@ -97,6 +121,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, C, H, W)
Output shape: (B, C, H, W) with residual skip.
Parameters
----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
......@@ -104,6 +139,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
# x: B, H, W, C (channels last)
B, H, W, C = x.shape
# flatten spatial dims
......@@ -118,8 +154,36 @@ class GlobalAttention(nn.Module):
class AttentionBlock(nn.Module):
"""
Neighborhood attention block based on Natten.
Parameters
----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (480, 960)
chans : int, optional
Number of channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dim to input dim, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : callable, optional
Activation function, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP, by default True
bias : bool, optional
Whether to use bias, by default True
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
"""
def __init__(
self,
in_shape=(480, 960),
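For completeness, a hedged sketch of the attention block with its documented defaults (illustrative, not part of the diff):
>>> import torch
>>> blk = AttentionBlock(in_shape=(480, 960), out_shape=(480, 960), chans=2,
...                      num_heads=1, attention_mode="neighborhood",
...                      attn_kernel_shape=(7, 7))
>>> y = blk(torch.randn(1, 2, 480, 960))      # shape assumed preserved by the residual path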
......
......@@ -43,6 +43,37 @@ from functools import partial
class DownsamplingBlock(nn.Module):
"""
Downsampling block for the UNet model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
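A minimal sketch of one encoder stage, assuming illustrative shapes and channel counts (not taken from the PR):
>>> import torch
>>> down = DownsamplingBlock(in_shape=(128, 256), out_shape=(64, 128),
...                          in_channels=16, out_channels=32, nrep=2,
...                          downsampling_mode="bilinear")
>>> y = down(torch.randn(1, 16, 128, 256))    # expected shape (1, 32, 64, 128)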
......@@ -146,6 +177,7 @@ class DownsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -166,6 +198,36 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module):
"""
Upsampling block for UNet architecture.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
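The decoder stage mirrors the encoder sketch above; again the shapes are illustrative assumptions:
>>> import torch
>>> up = UpsamplingBlock(in_shape=(64, 128), out_shape=(128, 256),
...                      in_channels=32, out_channels=16, upsampling_mode="bilinear")
>>> y = up(torch.randn(1, 32, 64, 128))       # expected shape (1, 16, 128, 256)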
......@@ -280,6 +342,7 @@ class UpsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -304,6 +367,7 @@ class UNet(nn.Module):
img_shape : tuple, optional
Shape of the input image (height, width), by default (128, 256)
kernel_shape : tuple or int
Kernel shape for convolutions
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
......@@ -336,11 +400,12 @@ class UNet(nn.Module):
... scale_factor=4,
... in_chans=2,
... num_classes=2,
... embed_dims=[64, 128, 256, 512],)
... embed_dims=[16, 32, 64, 128],
... depths=[2, 2, 2, 2],
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_shape=(128, 256),
......@@ -450,7 +515,7 @@ class UNet(nn.Module):
def forward(self, x):
# encoder:
features = []
feat = x
......
......@@ -68,9 +68,6 @@ def count_parameters(model):
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
......
......@@ -39,7 +39,7 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
# prepare dicts containing models and corresponding metrics
model_registry = dict(
sfno_sc2_layers4_e32 = partial(
......
......@@ -68,14 +68,13 @@ import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
......@@ -178,7 +177,7 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
train_start = time.time()
# set AMP type
......
......@@ -68,9 +68,6 @@ def count_parameters(model):
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes
"""
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
......@@ -238,7 +235,7 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
train_start = time.time()
# set AMP type
......
......@@ -55,6 +55,7 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
def get_compile_args(module_name):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
......@@ -77,7 +78,8 @@ def get_compile_args(module_name):
}
def get_ext_modules():
"""Get list of extension modules to compile."""
ext_modules = []
cmdclass = {}
......
......@@ -67,6 +67,8 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
@parameterized_class(("device"), _devices)
class TestNeighborhoodAttentionS2(unittest.TestCase):
"""Test the neighborhood attention module (CPU/CUDA if available)."""
def setUp(self):
torch.manual_seed(333)
if self.device.type == "cuda":
......
......@@ -36,7 +36,6 @@ import torch
class TestCacheConsistency(unittest.TestCase):
def test_consistency(self, verbose=False):
if verbose:
print("Testing that cache values does not get modified externally")
......
......@@ -47,9 +47,7 @@ if torch.cuda.is_available():
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
"""
"""Discretely normalizes the convolution tensor."""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
correction_factor = nlon_out / nlon_in
......@@ -98,10 +96,7 @@ def _precompute_convolution_tensor_dense(
basis_norm_mode="none",
merge_quadrature=False,
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""
"""Helper routine to compute the convolution Tensor in a dense fashion."""
assert len(in_shape) == 2
assert len(out_shape) == 2
......@@ -168,6 +163,8 @@ def _precompute_convolution_tensor_dense(
@parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
"""Test the discrete-continuous convolution module (CPU/CUDA if available)."""
def setUp(self):
torch.manual_seed(333)
if self.device.type == "cuda":
......
......@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
"""Test the distributed discrete-continuous convolution module."""
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
......@@ -118,6 +118,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -130,6 +131,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
......@@ -157,6 +159,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes
......@@ -204,7 +207,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
def test_distributed_disco_conv(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict(
......@@ -238,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = conv_local(inp_full)
......@@ -254,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
......@@ -268,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
......@@ -278,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
......
......@@ -41,6 +41,7 @@ import torch_harmonics.distributed as thd
class TestDistributedResampling(unittest.TestCase):
"""Test the distributed resampling module (CPU/CUDA if available)."""
@classmethod
def setUpClass(cls):
......@@ -118,6 +119,7 @@ class TestDistributedResampling(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -130,6 +132,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
......@@ -157,6 +160,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
# we need the shapes
lat_shapes = resampling_dist.lat_in_shapes
lon_shapes = resampling_dist.lon_in_shapes
......@@ -196,7 +200,7 @@ class TestDistributedResampling(unittest.TestCase):
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict(
......@@ -216,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = res_local(inp_full)
......@@ -232,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
......@@ -246,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
......@@ -256,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
......
......@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
"""Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
......@@ -163,6 +163,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
lat_shapes = transform_dist.lat_shapes
lon_shapes = transform_dist.lon_shapes
......@@ -214,6 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles
......@@ -230,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
......@@ -246,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
......@@ -260,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
......@@ -270,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
......@@ -301,6 +295,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
......@@ -340,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
......@@ -354,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
......@@ -364,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
......
......@@ -42,7 +42,7 @@ if torch.cuda.is_available():
class TestLegendrePolynomials(unittest.TestCase):
"""Test the associated Legendre polynomials (CPU/CUDA if available)."""
def setUp(self):
self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict()
......@@ -79,7 +79,7 @@ class TestLegendrePolynomials(unittest.TestCase):
@parameterized_class(("device"), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
"""Test the spherical harmonic transform (CPU/CUDA if available)."""
def setUp(self):
torch.manual_seed(333)
if self.device.type == "cuda":
......
......@@ -42,7 +42,7 @@ except ImportError as err:
# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations."""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
......@@ -67,6 +67,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
......@@ -81,6 +82,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
......@@ -97,6 +99,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
......@@ -111,6 +114,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
......@@ -140,6 +144,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
"""
assert len(psi.shape) == 3
assert len(x.shape) == 4
psi = psi.to(x.device)
......@@ -171,11 +176,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
"""
assert len(psi.shape) == 3
assert len(x.shape) == 5
psi = psi.to(x.device)
......
......@@ -50,8 +50,6 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
# prepare result tensor
y = torch.zeros_like(qy)
......@@ -170,7 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
# shapes:
# input
# kx: B, C, Hi, Wi
......@@ -252,6 +249,7 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
# shapes:
# input
# kx: B, C, Hi, Wi
......@@ -329,7 +327,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.nlon_in = nlon_in
......@@ -443,7 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out)
......@@ -451,6 +449,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
......@@ -458,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.max_psi_nnz = max_psi_nnz
......@@ -584,7 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out)
......@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
self.proj_bias = None
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......
......@@ -35,6 +35,32 @@ from copy import deepcopy
# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
"""
Least Recently Used (LRU) cache decorator with optional deep copying.
This is a wrapper around functools.lru_cache that adds the ability to return
deep copies of cached results to prevent unintended modifications to cached objects.
Parameters
----------
maxsize : int, optional
Maximum number of items to cache, by default 20
typed : bool, optional
Whether to cache different types separately, by default False
copy : bool, optional
Whether to return deep copies of cached results, by default False
Returns
-------
function
Decorated function with LRU caching
Example
-------
>>> @lru_cache(maxsize=10, copy=True)
... def expensive_function(x):
... return [x, x*2, x*3]
"""
def decorator(f):
cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
def wrapper(*args, **kwargs):
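A short sketch of why copy=True matters, based on the behaviour described in the docstring above (the function and values are hypothetical):
>>> @lru_cache(maxsize=4, copy=True)
... def make_list(x):
...     return [x, 2 * x, 3 * x]
>>> a = make_list(1)
>>> a.append(99)            # mutating the returned copy ...
>>> make_list(1)            # ... leaves the cached value untouched
[1, 2, 3]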
......