Unverified Commit c7afb546 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR
parents b5c410c0 644465ba
Pipeline #2854 canceled with stages
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
* Reorganized examples folder, including new examples based on the 2d3ds dataset * Reorganized examples folder, including new examples based on the 2d3ds dataset
* Added spherical loss functions to examples * Added spherical loss functions to examples
* Added plotting module * Added plotting module
* Updated docstrings
### v0.7.6 ### v0.7.6
......
...@@ -41,6 +41,24 @@ from functools import partial ...@@ -41,6 +41,24 @@ from functools import partial
class OverlapPatchMerging(nn.Module): class OverlapPatchMerging(nn.Module):
"""
OverlapPatchMerging layer for merging patches.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
kernel_shape : tuple
Kernel shape for convolution
bias : bool, optional
Whether to use bias, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(721, 1440), in_shape=(721, 1440),
...@@ -88,6 +106,30 @@ class OverlapPatchMerging(nn.Module): ...@@ -88,6 +106,30 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module): class MixFFN(nn.Module):
"""
MixFFN module combining MLP and depthwise convolution.
Parameters
----------
shape : tuple
Input shape (height, width)
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP layers, by default True
kernel_shape : tuple, optional
Kernel shape for depthwise convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : callable, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
"""
def __init__( def __init__(
self, self,
shape, shape,
...@@ -142,7 +184,7 @@ class MixFFN(nn.Module): ...@@ -142,7 +184,7 @@ class MixFFN(nn.Module):
x = x.permute(0, 3, 1, 2) x = x.permute(0, 3, 1, 2)
# NOTE: we add another activation here # NOTE: we add another activation here
# because in the paper they only use depthwise conv, # because in the paper the authors only use depthwise conv,
# but without this activation it would just be a fused MM # but without this activation it would just be a fused MM
# with the disco conv # with the disco conv
x = self.mlp_in(x) x = self.mlp_in(x)
...@@ -162,6 +204,17 @@ class GlobalAttention(nn.Module): ...@@ -162,6 +204,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, C, H, W) Input shape: (B, C, H, W)
Output shape: (B, C, H, W) with residual skip. Output shape: (B, C, H, W) with residual skip.
Parameters
----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
""" """
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True): def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
...@@ -169,6 +222,7 @@ class GlobalAttention(nn.Module): ...@@ -169,6 +222,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias) self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x): def forward(self, x):
# x: B, C, H, W # x: B, C, H, W
B, H, W, C = x.shape B, H, W, C = x.shape
# flatten spatial dims # flatten spatial dims
...@@ -181,6 +235,30 @@ class GlobalAttention(nn.Module): ...@@ -181,6 +235,30 @@ class GlobalAttention(nn.Module):
class AttentionWrapper(nn.Module): class AttentionWrapper(nn.Module):
"""
Wrapper for different attention mechanisms.
Parameters
----------
channels : int
Number of channels
shape : tuple
Input shape (height, width)
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True): def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
super().__init__() super().__init__()
...@@ -203,11 +281,13 @@ class AttentionWrapper(nn.Module): ...@@ -203,11 +281,13 @@ class AttentionWrapper(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.LayerNorm): if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x residual = x
x = x.permute(0, 2, 3, 1) x = x.permute(0, 2, 3, 1)
if self.norm is not None: if self.norm is not None:
...@@ -219,6 +299,41 @@ class AttentionWrapper(nn.Module): ...@@ -219,6 +299,41 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
"""
Transformer block with attention and MLP.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -341,6 +456,33 @@ class TransformerBlock(nn.Module): ...@@ -341,6 +456,33 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module): class Upsampling(nn.Module):
"""
Upsampling block for the Segformer model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -382,7 +524,7 @@ class Segformer(nn.Module): ...@@ -382,7 +524,7 @@ class Segformer(nn.Module):
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters Parameters
----------- ----------
img_shape : tuple, optional img_shape : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int kernel_shape: tuple, int
...@@ -414,7 +556,7 @@ class Segformer(nn.Module): ...@@ -414,7 +556,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example Example
----------- ----------
>>> model = Segformer( >>> model = Segformer(
... img_size=(128, 256), ... img_size=(128, 256),
... in_chans=3, ... in_chans=3,
......
...@@ -57,11 +57,34 @@ class Encoder(nn.Module): ...@@ -57,11 +57,34 @@ class Encoder(nn.Module):
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups) self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
return x return x
class Decoder(nn.Module): class Decoder(nn.Module):
"""
Decoder module for upsampling and feature processing.
Parameters
----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (721, 1440)
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
groups : int, optional
Number of groups for convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsampling_method : str, optional
Upsampling method ("conv", "pixel_shuffle"), by default "conv"
"""
def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"): def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
super().__init__() super().__init__()
self.out_shape = out_shape self.out_shape = out_shape
...@@ -87,6 +110,7 @@ class Decoder(nn.Module): ...@@ -87,6 +110,7 @@ class Decoder(nn.Module):
raise ValueError(f"Unknown upsampling method {upsampling_method}") raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x): def forward(self, x):
x = self.upsample(x) x = self.upsample(x)
return x return x
...@@ -97,6 +121,17 @@ class GlobalAttention(nn.Module): ...@@ -97,6 +121,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, C, H, W) Input shape: (B, C, H, W)
Output shape: (B, C, H, W) with residual skip. Output shape: (B, C, H, W) with residual skip.
Parameters
----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
""" """
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True): def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
...@@ -104,6 +139,7 @@ class GlobalAttention(nn.Module): ...@@ -104,6 +139,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias) self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x): def forward(self, x):
# x: B, C, H, W # x: B, C, H, W
B, H, W, C = x.shape B, H, W, C = x.shape
# flatten spatial dims # flatten spatial dims
...@@ -118,8 +154,36 @@ class GlobalAttention(nn.Module): ...@@ -118,8 +154,36 @@ class GlobalAttention(nn.Module):
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
""" """
Neighborhood attention block based on Natten. Neighborhood attention block based on Natten.
Parameters
----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (480, 960)
chans : int, optional
Number of channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dim to input dim, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : callable, optional
Activation function, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP, by default True
bias : bool, optional
Whether to use bias, by default True
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
""" """
def __init__( def __init__(
self, self,
in_shape=(480, 960), in_shape=(480, 960),
......
...@@ -43,6 +43,37 @@ from functools import partial ...@@ -43,6 +43,37 @@ from functools import partial
class DownsamplingBlock(nn.Module): class DownsamplingBlock(nn.Module):
"""
Downsampling block for the UNet model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -146,6 +177,7 @@ class DownsamplingBlock(nn.Module): ...@@ -146,6 +177,7 @@ class DownsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection # skip connection
residual = x residual = x
if hasattr(self, "transform_skip"): if hasattr(self, "transform_skip"):
...@@ -166,6 +198,36 @@ class DownsamplingBlock(nn.Module): ...@@ -166,6 +198,36 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module): class UpsamplingBlock(nn.Module):
"""
Upsampling block for UNet architecture.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -280,6 +342,7 @@ class UpsamplingBlock(nn.Module): ...@@ -280,6 +342,7 @@ class UpsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection # skip connection
residual = x residual = x
if hasattr(self, "transform_skip"): if hasattr(self, "transform_skip"):
...@@ -304,6 +367,7 @@ class UNet(nn.Module): ...@@ -304,6 +367,7 @@ class UNet(nn.Module):
img_shape : tuple, optional img_shape : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int kernel_shape: tuple, int
Kernel shape for convolutions
scale_factor: int, optional scale_factor: int, optional
Scale factor to use, by default 2 Scale factor to use, by default 2
in_chans : int, optional in_chans : int, optional
...@@ -336,11 +400,12 @@ class UNet(nn.Module): ...@@ -336,11 +400,12 @@ class UNet(nn.Module):
... scale_factor=4, ... scale_factor=4,
... in_chans=2, ... in_chans=2,
... num_classes=2, ... num_classes=2,
... embed_dims=[64, 128, 256, 512],) ... embed_dims=[16, 32, 64, 128],
... depths=[2, 2, 2, 2],
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape >>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256]) torch.Size([1, 2, 128, 256])
""" """
def __init__( def __init__(
self, self,
img_shape=(128, 256), img_shape=(128, 256),
...@@ -450,7 +515,7 @@ class UNet(nn.Module): ...@@ -450,7 +515,7 @@ class UNet(nn.Module):
def forward(self, x): def forward(self, x):
# encoder: # encoder:
features = [] features = []
feat = x feat = x
......
...@@ -68,9 +68,6 @@ def count_parameters(model): ...@@ -68,9 +68,6 @@ def count_parameters(model):
# convenience function for logging weights and gradients # convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1): def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads") log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path): if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True) os.makedirs(log_path, exist_ok=True)
......
...@@ -39,7 +39,7 @@ from baseline_models import Transformer, UNet, Segformer ...@@ -39,7 +39,7 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"): def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
# prepare dicts containing models and corresponding metrics # prepare dicts containing models and corresponding metrics
model_registry = dict( model_registry = dict(
sfno_sc2_layers4_e32 = partial( sfno_sc2_layers4_e32 = partial(
......
...@@ -68,14 +68,13 @@ import wandb ...@@ -68,14 +68,13 @@ import wandb
# helper routine for counting number of parameters in model # helper routine for counting number of parameters in model
def count_parameters(model): def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients # convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1): def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes
"""
log_path = os.path.join(exp_dir, "weights_and_grads") log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path): if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True) os.makedirs(log_path, exist_ok=True)
...@@ -178,7 +177,7 @@ def train_model( ...@@ -178,7 +177,7 @@ def train_model(
logging=True, logging=True,
device=torch.device("cpu"), device=torch.device("cpu"),
): ):
train_start = time.time() train_start = time.time()
# set AMP type # set AMP type
......
...@@ -68,9 +68,6 @@ def count_parameters(model): ...@@ -68,9 +68,6 @@ def count_parameters(model):
# convenience function for logging weights and gradients # convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1): def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes
"""
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads") root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar") weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
...@@ -238,7 +235,7 @@ def train_model( ...@@ -238,7 +235,7 @@ def train_model(
logging=True, logging=True,
device=torch.device("cpu"), device=torch.device("cpu"),
): ):
train_start = time.time() train_start = time.time()
# set AMP type # set AMP type
......
...@@ -55,6 +55,7 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e: ...@@ -55,6 +55,7 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
def get_compile_args(module_name): def get_compile_args(module_name):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build""" """If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1' debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1' profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
...@@ -77,7 +78,8 @@ def get_compile_args(module_name): ...@@ -77,7 +78,8 @@ def get_compile_args(module_name):
} }
def get_ext_modules(): def get_ext_modules():
"""Get list of extension modules to compile."""
ext_modules = [] ext_modules = []
cmdclass = {} cmdclass = {}
......
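The TORCH_HARMONICS_DEBUG / TORCH_HARMONICS_PROFILE switches above follow a common build pattern: read an environment variable at build time and choose compiler flags accordingly. A minimal, hypothetical sketch of that pattern is shown below; the flag sets are placeholders, not the project's actual flags.

```python
import os

def get_compile_args_sketch(module_name: str) -> dict:
    """Illustrative only: pick compiler flags based on an environment variable."""
    debug_mode = os.environ.get("TORCH_HARMONICS_DEBUG", "0") == "1"
    cxx_flags = ["-g", "-O0"] if debug_mode else ["-O3"]   # placeholder flag sets
    return {"cxx": cxx_flags}

# Example: running TORCH_HARMONICS_DEBUG=1 python setup.py build_ext hits the debug branch.
print(get_compile_args_sketch("disco"))
```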
...@@ -67,6 +67,8 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150} ...@@ -67,6 +67,8 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
@parameterized_class(("device"), _devices) @parameterized_class(("device"), _devices)
class TestNeighborhoodAttentionS2(unittest.TestCase): class TestNeighborhoodAttentionS2(unittest.TestCase):
"""Test the neighborhood attention module (CPU/CUDA if available)."""
def setUp(self): def setUp(self):
torch.manual_seed(333) torch.manual_seed(333)
if self.device.type == "cuda": if self.device.type == "cuda":
......
...@@ -36,7 +36,6 @@ import torch ...@@ -36,7 +36,6 @@ import torch
class TestCacheConsistency(unittest.TestCase): class TestCacheConsistency(unittest.TestCase):
def test_consistency(self, verbose=False): def test_consistency(self, verbose=False):
if verbose: if verbose:
print("Testing that cache values does not get modified externally") print("Testing that cache values does not get modified externally")
......
...@@ -47,9 +47,7 @@ if torch.cuda.is_available(): ...@@ -47,9 +47,7 @@ if torch.cuda.is_available():
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9): def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
""" """Discretely normalizes the convolution tensor."""
Discretely normalizes the convolution tensor.
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
correction_factor = nlon_out / nlon_in correction_factor = nlon_out / nlon_in
...@@ -98,10 +96,7 @@ def _precompute_convolution_tensor_dense( ...@@ -98,10 +96,7 @@ def _precompute_convolution_tensor_dense(
basis_norm_mode="none", basis_norm_mode="none",
merge_quadrature=False, merge_quadrature=False,
): ):
""" """Helper routine to compute the convolution Tensor in a dense fashion."""
Helper routine to compute the convolution Tensor in a dense fashion
"""
assert len(in_shape) == 2 assert len(in_shape) == 2
assert len(out_shape) == 2 assert len(out_shape) == 2
...@@ -168,6 +163,8 @@ def _precompute_convolution_tensor_dense( ...@@ -168,6 +163,8 @@ def _precompute_convolution_tensor_dense(
@parameterized_class(("device"), _devices) @parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase): class TestDiscreteContinuousConvolution(unittest.TestCase):
"""Test the discrete-continuous convolution module (CPU/CUDA if available)."""
def setUp(self): def setUp(self):
torch.manual_seed(333) torch.manual_seed(333)
if self.device.type == "cuda": if self.device.type == "cuda":
......
...@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd ...@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
"""Test the distributed discrete-continuous convolution module."""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# set up distributed # set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0)) cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1)) cls.grid_size_h = int(os.getenv("GRID_H", 1))
...@@ -118,6 +118,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -118,6 +118,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
dist.destroy_process_group(None) dist.destroy_process_group(None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w) tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
...@@ -130,6 +131,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -130,6 +131,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist): def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_out_shapes lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes lon_shapes = convolution_dist.lon_out_shapes
...@@ -157,6 +159,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -157,6 +159,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist): def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_in_shapes lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes lon_shapes = convolution_dist.lon_in_shapes
...@@ -204,7 +207,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -204,7 +207,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
def test_distributed_disco_conv( def test_distributed_disco_conv(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
): ):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict( disco_args = dict(
...@@ -238,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -238,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# create tensors # create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv # local conv
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = conv_local(inp_full) out_full = conv_local(inp_full)
...@@ -254,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -254,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv # distributed conv
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -268,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -268,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -278,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -278,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
......
...@@ -41,6 +41,7 @@ import torch_harmonics.distributed as thd ...@@ -41,6 +41,7 @@ import torch_harmonics.distributed as thd
class TestDistributedResampling(unittest.TestCase): class TestDistributedResampling(unittest.TestCase):
"""Test the distributed resampling module (CPU/CUDA if available)."""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -118,6 +119,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -118,6 +119,7 @@ class TestDistributedResampling(unittest.TestCase):
dist.destroy_process_group(None) dist.destroy_process_group(None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w) tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
...@@ -130,6 +132,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -130,6 +132,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist): def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_out_shapes lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes lon_shapes = convolution_dist.lon_out_shapes
...@@ -157,6 +160,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -157,6 +160,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist): def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
# we need the shapes # we need the shapes
lat_shapes = resampling_dist.lat_in_shapes lat_shapes = resampling_dist.lat_in_shapes
lon_shapes = resampling_dist.lon_in_shapes lon_shapes = resampling_dist.lon_in_shapes
...@@ -196,7 +200,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -196,7 +200,7 @@ class TestDistributedResampling(unittest.TestCase):
def test_distributed_resampling( def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
): ):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict( res_args = dict(
...@@ -216,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -216,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
# create tensors # create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv # local conv
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = res_local(inp_full) out_full = res_local(inp_full)
...@@ -232,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -232,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv # distributed conv
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -246,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -246,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -256,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -256,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
......
...@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd ...@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
class TestDistributedSphericalHarmonicTransform(unittest.TestCase): class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
"""Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# set up distributed # set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0)) cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1)) cls.grid_size_h = int(os.getenv("GRID_H", 1))
...@@ -163,6 +163,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -163,6 +163,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector): def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes # we need the shapes
lat_shapes = transform_dist.lat_shapes lat_shapes = transform_dist.lat_shapes
lon_shapes = transform_dist.lon_shapes lon_shapes = transform_dist.lon_shapes
...@@ -214,6 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -214,6 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
] ]
) )
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol): def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles # set up handles
...@@ -230,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -230,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
else: else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform # local transform
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = forward_transform_local(inp_full) out_full = forward_transform_local(inp_full)
...@@ -246,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -246,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform # distributed transform
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -260,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -260,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector) out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -270,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -270,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
...@@ -301,6 +295,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -301,6 +295,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
] ]
) )
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol): def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon B, C, H, W = batch_size, num_chan, nlat, nlon
if vector: if vector:
...@@ -340,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -340,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform # distributed transform
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -354,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -354,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector) out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -364,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -364,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector) igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
......
...@@ -42,7 +42,7 @@ if torch.cuda.is_available(): ...@@ -42,7 +42,7 @@ if torch.cuda.is_available():
class TestLegendrePolynomials(unittest.TestCase): class TestLegendrePolynomials(unittest.TestCase):
"""Test the associated Legendre polynomials (CPU/CUDA if available)."""
def setUp(self): def setUp(self):
self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m)) self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict() self.pml = dict()
...@@ -79,7 +79,7 @@ class TestLegendrePolynomials(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestLegendrePolynomials(unittest.TestCase):
@parameterized_class(("device"), _devices) @parameterized_class(("device"), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase): class TestSphericalHarmonicTransform(unittest.TestCase):
"""Test the spherical harmonic transform (CPU/CUDA if available)."""
def setUp(self): def setUp(self):
torch.manual_seed(333) torch.manual_seed(333)
if self.device.type == "cuda": if self.device.type == "cuda":
......
...@@ -42,7 +42,7 @@ except ImportError as err: ...@@ -42,7 +42,7 @@ except ImportError as err:
# some helper functions # some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False): def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations."""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
...@@ -67,6 +67,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function): ...@@ -67,6 +67,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int): kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2] ctx.nlat_in = x.shape[-2]
...@@ -81,6 +82,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function): ...@@ -81,6 +82,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous() grad_output = grad_output.to(torch.float32).contiguous()
...@@ -97,6 +99,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function): ...@@ -97,6 +99,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int): kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2] ctx.nlat_in = x.shape[-2]
...@@ -111,6 +114,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function): ...@@ -111,6 +114,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous() grad_output = grad_output.to(torch.float32).contiguous()
...@@ -140,6 +144,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in ...@@ -140,6 +144,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
shifting of the input tensor, which can potentially be costly. For an efficient implementation shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA. on GPU, make sure to use the custom kernel written in CUDA.
""" """
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 4 assert len(x.shape) == 4
psi = psi.to(x.device) psi = psi.to(x.device)
...@@ -171,11 +176,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in ...@@ -171,11 +176,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
"""
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 5 assert len(x.shape) == 5
psi = psi.to(x.device) psi = psi.to(x.device)
......
...@@ -50,8 +50,6 @@ except ImportError as err: ...@@ -50,8 +50,6 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
# prepare result tensor # prepare result tensor
y = torch.zeros_like(qy) y = torch.zeros_like(qy)
...@@ -170,7 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, ...@@ -170,7 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int): nlon_in: int, nlat_out: int, nlon_out: int):
# shapes: # shapes:
# input # input
# kx: B, C, Hi, Wi # kx: B, C, Hi, Wi
...@@ -252,6 +249,7 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, ...@@ -252,6 +249,7 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int): nlon_in: int, nlat_out: int, nlon_out: int):
# shapes: # shapes:
# input # input
# kx: B, C, Hi, Wi # kx: B, C, Hi, Wi
...@@ -329,7 +327,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function): ...@@ -329,7 +327,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int): nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh ctx.nh = nh
ctx.nlon_in = nlon_in ctx.nlon_in = nlon_in
...@@ -443,7 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch. ...@@ -443,7 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq, return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out) nh, nlon_in, nlat_out, nlon_out)
...@@ -451,6 +449,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch. ...@@ -451,6 +449,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(device_type="cuda") @custom_fwd(device_type="cuda")
def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
...@@ -458,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -458,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int): max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh ctx.nh = nh
ctx.max_psi_nnz = max_psi_nnz ctx.max_psi_nnz = max_psi_nnz
...@@ -584,7 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T ...@@ -584,7 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq, return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz, quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out) nh, nlon_in, nlat_out, nlon_out)
...@@ -142,9 +142,6 @@ class AttentionS2(nn.Module): ...@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
...@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module): ...@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
self.proj_bias = None self.proj_bias = None
def extra_repr(self): def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
......
...@@ -35,6 +35,32 @@ from copy import deepcopy ...@@ -35,6 +35,32 @@ from copy import deepcopy
# copying LRU cache decorator a la: # copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances # https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False): def lru_cache(maxsize=20, typed=False, copy=False):
"""
Least Recently Used (LRU) cache decorator with optional deep copying.
This is a wrapper around functools.lru_cache that adds the ability to return
deep copies of cached results to prevent unintended modifications to cached objects.
Parameters
----------
maxsize : int, optional
Maximum number of items to cache, by default 20
typed : bool, optional
Whether to cache different types separately, by default False
copy : bool, optional
Whether to return deep copies of cached results, by default False
Returns
-------
function
Decorated function with LRU caching
Example
-------
>>> @lru_cache(maxsize=10, copy=True)
... def expensive_function(x):
... return [x, x*2, x*3]
"""
def decorator(f): def decorator(f):
cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f) cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
......
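To illustrate why the copy flag documented above matters, the snippet below mimics the pattern with functools.lru_cache plus copy.deepcopy: with copy=True a caller can mutate the returned object without corrupting the cached value. This is a self-contained sketch, not the repository's exact code.

```python
import functools
from copy import deepcopy

def lru_cache(maxsize=20, typed=False, copy=False):
    """LRU cache decorator that can hand out deep copies of cached results."""
    def decorator(f):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
        if not copy:
            return cached
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # return a deep copy so callers cannot mutate the cached object
            return deepcopy(cached(*args, **kwargs))
        return wrapper
    return decorator

@lru_cache(maxsize=10, copy=True)
def expensive_function(x):
    return [x, x * 2, x * 3]

first = expensive_function(2)
first.append(999)                 # mutate the returned copy only
print(expensive_function(2))      # [2, 4, 6] -- the cached value is unaffected
```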