OpenDAS / torch-harmonics · Commits

Commit c7afb546 (unverified)
Authored Jul 21, 2025 by Thorsten Kurth, committed by GitHub on Jul 21, 2025

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR

Parents: b5c410c0, 644465ba
Pipeline #2854: canceled with stages
Changes: 44 · Pipelines: 2
Showing 20 changed files with 347 additions and 83 deletions (+347, -83)
Changelog.md  (+1, -0)
examples/baseline_models/segformer.py  (+145, -3)
examples/baseline_models/transformer.py  (+65, -1)
examples/baseline_models/unet.py  (+68, -3)
examples/depth/train.py  (+0, -3)
examples/model_registry.py  (+1, -1)
examples/segmentation/train.py  (+3, -4)
examples/shallow_water_equations/train.py  (+1, -4)
setup.py  (+3, -1)
tests/test_attention.py  (+2, -0)
tests/test_cache.py  (+0, -1)
tests/test_convolution.py  (+4, -7)
tests/test_distributed_convolution.py  (+5, -10)
tests/test_distributed_resample.py  (+5, -9)
tests/test_distributed_sht.py  (+4, -15)
tests/test_sht.py  (+2, -2)
torch_harmonics/_disco_convolution.py  (+6, -6)
torch_harmonics/_neighborhood_attention.py  (+6, -7)
torch_harmonics/attention.py  (+0, -6)
torch_harmonics/cache.py  (+26, -0)
Changelog.md

@@ -14,6 +14,7 @@
 * Reorganized examples folder, including new examples based on the 2d3ds dataset
 * Added spherical loss functions to examples
 * Added plotting module
+* Updated docstrings

 ### v0.7.6
examples/baseline_models/segformer.py

@@ -41,6 +41,24 @@ from functools import partial

class OverlapPatchMerging(nn.Module):
    """
    OverlapPatchMerging layer for merging patches.

    Parameters
    -----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_shape : tuple
        Kernel shape for convolution
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, in_shape=(721, 1440),
    ...

@@ -88,6 +106,30 @@ class OverlapPatchMerging(nn.Module):

class MixFFN(nn.Module):
    """
    MixFFN module combining MLP and depthwise convolution.

    Parameters
    -----------
    shape : tuple
        Input shape (height, width)
    inout_channels : int
        Number of input/output channels
    hidden_channels : int
        Number of hidden channels in MLP
    mlp_bias : bool, optional
        Whether to use bias in MLP layers, by default True
    kernel_shape : tuple, optional
        Kernel shape for depthwise convolution, by default (3, 3)
    conv_bias : bool, optional
        Whether to use bias in convolution, by default False
    activation : callable, optional
        Activation function, by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP instead of linear layers, by default False
    drop_path : float, optional
        Drop path rate, by default 0.0
    """

    def __init__(self, shape,
    ...

@@ -142,7 +184,7 @@ class MixFFN(nn.Module):
        x = x.permute(0, 3, 1, 2)
        # NOTE: we add another activation here
-       # because in the paper the y only use depthwise conv,
+       # because in the paper the authors only use depthwise conv,
        # but without this activation it would just be a fused MM
        # with the disco conv
        x = self.mlp_in(x)
    ...

@@ -162,6 +204,17 @@ class GlobalAttention(nn.Module):
    Input shape: (B, C, H, W)
    Output shape: (B, C, H, W) with residual skip.

    Parameters
    -----------
    chans : int
        Number of channels
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout rate, by default 0.0
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
    ...

@@ -169,6 +222,7 @@ class GlobalAttention(nn.Module):
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        # x: B, C, H, W
        B, H, W, C = x.shape
        # flatten spatial dims
    ...

@@ -181,6 +235,30 @@ class GlobalAttention(nn.Module):

class AttentionWrapper(nn.Module):
    """
    Wrapper for different attention mechanisms.

    Parameters
    -----------
    channels : int
        Number of channels
    shape : tuple
        Input shape (height, width)
    heads : int
        Number of attention heads
    pre_norm : bool, optional
        Whether to apply normalization before attention, by default False
    attention_drop_rate : float, optional
        Attention dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood", "global"), by default "neighborhood"
    kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (7, 7)
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0,
                 attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
        super().__init__()
    ...

@@ -203,11 +281,13 @@ class AttentionWrapper(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.permute(0, 2, 3, 1)
        if self.norm is not None:
    ...

@@ -219,6 +299,41 @@ class AttentionWrapper(nn.Module):

class TransformerBlock(nn.Module):
    """
    Transformer block with attention and MLP.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    mlp_hidden_channels : int
        Number of hidden channels in MLP
    nrep : int, optional
        Number of repetitions of attention and MLP blocks, by default 1
    heads : int, optional
        Number of attention heads, by default 1
    kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (3, 3)
    activation : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    att_drop_rate : float, optional
        Attention dropout rate, by default 0.0
    drop_path_rates : float or list, optional
        Drop path rates for each block, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood", "global"), by default "neighborhood"
    attn_kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (7, 7)
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, in_shape,
    ...

@@ -341,6 +456,33 @@ class TransformerBlock(nn.Module):

class Upsampling(nn.Module):
    """
    Upsampling block for the Segformer model.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : int
        Number of hidden channels in MLP
    mlp_bias : bool, optional
        Whether to use bias in MLP, by default True
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    conv_bias : bool, optional
        Whether to use bias in convolution, by default False
    activation : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP, by default False
    """

    def __init__(self, in_shape,
    ...

@@ -382,7 +524,7 @@ class Segformer(nn.Module):
    Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks

    Parameters
-   -----------
+   ----------
    img_shape : tuple, optional
        Shape of the input channels, by default (128, 256)
    kernel_shape: tuple, int
    ...

@@ -414,7 +556,7 @@ class Segformer(nn.Module):
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"

    Example
-   -----------
+   ----------
    >>> model = Segformer(
    ...     img_size=(128, 256),
    ...     in_chans=3,
    ...
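Note: the GlobalAttention forward pass is collapsed after the "flatten spatial dims" comment. Below is a minimal sketch of how such a block can be wired with nn.MultiheadAttention, assuming a channels-last flatten (as suggested by the "B, H, W, C = x.shape" unpacking) and a residual skip as described in the docstring; the reshape details and the class name GlobalAttentionSketch are assumptions, not the repository's exact code.

import torch
import torch.nn as nn

class GlobalAttentionSketch(nn.Module):
    # hypothetical stand-in for the GlobalAttention block shown above
    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads,
                                          dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        # x assumed channels-last: B, H, W, C
        B, H, W, C = x.shape
        # flatten spatial dims into a sequence of length H*W
        seq = x.reshape(B, H * W, C)
        out, _ = self.attn(seq, seq, seq)
        # restore spatial dims and add the residual skip (assumed)
        return x + out.reshape(B, H, W, C)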
examples/baseline_models/transformer.py

@@ -57,11 +57,34 @@ class Encoder(nn.Module):
        self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias,
                              stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)

    def forward(self, x):
        x = self.conv(x)
        return x


class Decoder(nn.Module):
    """
    Decoder module for upsampling and feature processing.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (height, width), by default (480, 960)
    out_shape : tuple, optional
        Output shape (height, width), by default (721, 1440)
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    groups : int, optional
        Number of groups for convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsampling_method : str, optional
        Upsampling method ("conv", "pixel_shuffle"), by default "conv"
    """

    def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2,
                 kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
        super().__init__()

        self.out_shape = out_shape
    ...

@@ -87,6 +110,7 @@ class Decoder(nn.Module):
            raise ValueError(f"Unknown upsampling method {upsampling_method}")

    def forward(self, x):
        x = self.upsample(x)
        return x
    ...

@@ -97,6 +121,17 @@ class GlobalAttention(nn.Module):
    Input shape: (B, C, H, W)
    Output shape: (B, C, H, W) with residual skip.

    Parameters
    -----------
    chans : int
        Number of channels
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout rate, by default 0.0
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
    ...

@@ -104,6 +139,7 @@ class GlobalAttention(nn.Module):
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        # x: B, C, H, W
        B, H, W, C = x.shape
        # flatten spatial dims
    ...

@@ -118,8 +154,36 @@ class GlobalAttention(nn.Module):

class AttentionBlock(nn.Module):
    """
    Neighborhood attention block based on Natten.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (height, width), by default (480, 960)
    out_shape : tuple, optional
        Output shape (height, width), by default (480, 960)
    chans : int, optional
        Number of channels, by default 2
    num_heads : int, optional
        Number of attention heads, by default 1
    mlp_ratio : float, optional
        Ratio of MLP hidden dim to input dim, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : callable, optional
        Activation function, by default nn.GELU
    norm_layer : str, optional
        Normalization layer type, by default "none"
    use_mlp : bool, optional
        Whether to use MLP, by default True
    bias : bool, optional
        Whether to use bias, by default True
    attention_mode : str, optional
        Attention mode ("neighborhood", "global"), by default "neighborhood"
    attn_kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (7, 7)
    """

    def __init__(self, in_shape=(480, 960),
    ...
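Note: the Decoder's constructor body is mostly collapsed; only the trailing ValueError for an unknown upsampling_method and the forward pass are visible. The sketch below illustrates how that dispatch could be structured; the concrete layers chosen for "conv" and "pixel_shuffle" are assumptions for illustration, not the repository's implementation.

import torch.nn as nn

class DecoderSketch(nn.Module):
    # hypothetical illustration of the upsampling_method dispatch seen in the diff
    def __init__(self, in_chans=2, out_chans=2, out_shape=(721, 1440),
                 kernel_shape=(3, 3), bias=False, upsampling_method="conv"):
        super().__init__()
        self.out_shape = out_shape
        if upsampling_method == "conv":
            # assumption: interpolate to the target shape, then refine with a conv
            self.upsample = nn.Sequential(
                nn.Upsample(size=out_shape, mode="bilinear", align_corners=False),
                nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape,
                          padding=tuple(k // 2 for k in kernel_shape), bias=bias),
            )
        elif upsampling_method == "pixel_shuffle":
            # assumption: expand channels, pixel-shuffle by 2, then resize to out_shape
            self.upsample = nn.Sequential(
                nn.Conv2d(in_chans, out_chans * 4, kernel_size=1, bias=bias),
                nn.PixelShuffle(upscale_factor=2),
                nn.Upsample(size=out_shape, mode="bilinear", align_corners=False),
            )
        else:
            raise ValueError(f"Unknown upsampling method {upsampling_method}")

    def forward(self, x):
        x = self.upsample(x)
        return x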
examples/baseline_models/unet.py

@@ -43,6 +43,37 @@ from functools import partial

class DownsamplingBlock(nn.Module):
    """
    Downsampling block for the UNet model.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    nrep : int, optional
        Number of repetitions of conv blocks, by default 1
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    activation : callable, optional
        Activation function, by default nn.ReLU
    transform_skip : bool, optional
        Whether to transform skip connections, by default False
    drop_conv_rate : float, optional
        Dropout rate for convolutions, by default 0.
    drop_path_rate : float, optional
        Drop path rate, by default 0.
    drop_dense_rate : float, optional
        Dropout rate for dense layers, by default 0.
    downsampling_mode : str, optional
        Downsampling mode ("bilinear", "conv"), by default "bilinear"
    """

    def __init__(self, in_shape,
    ...

@@ -146,6 +177,7 @@ class DownsamplingBlock(nn.Module):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
    ...

@@ -166,6 +198,36 @@ class DownsamplingBlock(nn.Module):

class UpsamplingBlock(nn.Module):
    """
    Upsampling block for UNet architecture.

    Parameters
    -----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    nrep : int, optional
        Number of repetitions of conv blocks, by default 1
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    activation : callable, optional
        Activation function, by default nn.ReLU
    transform_skip : bool, optional
        Whether to transform skip connections, by default False
    drop_conv_rate : float, optional
        Dropout rate for convolutions, by default 0.
    drop_path_rate : float, optional
        Drop path rate, by default 0.
    drop_dense_rate : float, optional
        Dropout rate for dense layers, by default 0.
    upsampling_mode : str, optional
        Upsampling mode ("bilinear", "conv"), by default "bilinear"
    """

    def __init__(self, in_shape,
    ...

@@ -280,6 +342,7 @@ class UpsamplingBlock(nn.Module):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
    ...

@@ -304,6 +367,7 @@ class UNet(nn.Module):
    img_shape : tuple, optional
        Shape of the input channels, by default (128, 256)
    kernel_shape: tuple, int
        Kernel shape for convolutions
    scale_factor: int, optional
        Scale factor to use, by default 2
    in_chans : int, optional
    ...

@@ -336,11 +400,12 @@ class UNet(nn.Module):
    ...     scale_factor=4,
    ...     in_chans=2,
    ...     num_classes=2,
-   ...     embed_dims=[64, 128, 256, 512],)
+   ...     embed_dims=[16, 32, 64, 128],
+   ...     depths=[2, 2, 2, 2],
+   ...     use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    def __init__(self, img_shape=(128, 256),
    ...

@@ -450,7 +515,7 @@ class UNet(nn.Module):
    def forward(self, x):
        # encoder:
        features = []
        feat = x
    ...
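Note: both UNet blocks open their forward pass with a "skip connection" and an hasattr(self, "transform_skip") check before the collapsed body. A minimal sketch of that pattern, assuming the skip branch is projected with a 1x1 convolution when transform_skip is requested or the channel count changes; the class SkipBlockSketch and its layers are illustrative only.

import torch
import torch.nn as nn

class SkipBlockSketch(nn.Module):
    # hypothetical illustration of the "transform_skip" residual pattern in the UNet blocks
    def __init__(self, in_channels, out_channels, transform_skip=False):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        if transform_skip or in_channels != out_channels:
            # project the skip branch so it matches the block output
            self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
            residual = self.transform_skip(residual)
        return self.body(x) + residual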
examples/depth/train.py

@@ -68,9 +68,6 @@ def count_parameters(model):

# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
    """
    Helper routine intended for debugging purposes
    """
    log_path = os.path.join(exp_dir, "weights_and_grads")
    if not os.path.isdir(log_path):
        os.makedirs(log_path, exist_ok=True)
    ...
examples/model_registry.py

@@ -39,7 +39,7 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer


def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):

    # prepare dicts containing models and corresponding metrics
    model_registry = dict(
        sfno_sc2_layers4_e32=partial(
    ...
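Note: get_baseline_models fills a dict of functools.partial constructors keyed by descriptive names such as sfno_sc2_layers4_e32. A hedged usage sketch follows; the import path and the assumption that the function returns the model_registry dict are mine (the visible comment mentions dicts for both models and metrics, so the actual return value may differ).

from model_registry import get_baseline_models  # assumed import path inside examples/

# build the registry of partially-applied model constructors
registry = get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3)

# pick an entry by name and instantiate it; remaining kwargs were bound via partial
model = registry["sfno_sc2_layers4_e32"]()
print(type(model).__name__)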
examples/segmentation/train.py

@@ -68,14 +68,13 @@ import wandb

# helper routine for counting number of parameters in model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
    """
    Helper routine intended for debugging purposes
    """
    log_path = os.path.join(exp_dir, "weights_and_grads")
    if not os.path.isdir(log_path):
        os.makedirs(log_path, exist_ok=True)
    ...

@@ -178,7 +177,7 @@ def train_model(
    logging=True,
    device=torch.device("cpu"),
):
    train_start = time.time()

    # set AMP type
    ...
examples/shallow_water_equations/train.py

@@ -68,9 +68,6 @@ def count_parameters(model):

# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
    """
    Helper routine intended for debugging purposes
    """
    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
    ...

@@ -238,7 +235,7 @@ def train_model(
    logging=True,
    device=torch.device("cpu"),
):
    train_start = time.time()

    # set AMP type
    ...
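Note: the body of log_weights_and_grads is collapsed after the output filename is constructed. A plausible completion, assuming the helper simply dumps named parameters together with their gradients via torch.save; the payload layout is an assumption, not necessarily the repository's exact code.

import os
import torch

def log_weights_and_grads_sketch(model, iters=1):
    """Hypothetical completion of the debugging helper shown in the diff."""
    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
    os.makedirs(root_path, exist_ok=True)
    weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")

    # assumption: store each parameter together with its current gradient
    payload = {
        name: {
            "weight": p.detach().cpu(),
            "grad": None if p.grad is None else p.grad.detach().cpu(),
        }
        for name, p in model.named_parameters()
    }
    torch.save(payload, weights_and_grads_fname)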
setup.py

@@ -55,6 +55,7 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:

def get_compile_args(module_name):
    """If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""

    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
    ...

@@ -77,7 +78,8 @@ def get_compile_args(module_name):
    }


def get_ext_modules():
    """Get list of extension modules to compile."""

    ext_modules = []
    cmdclass = {}
    ...
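Note: only the environment-variable switches of get_compile_args are visible here. The sketch below shows how such a helper typically maps debug_mode and profile_mode onto compiler flags; the specific flag values are assumptions for illustration and are not the flags torch-harmonics actually passes.

import os

def get_compile_args_sketch(module_name):
    """Illustrative only: choose build flags based on TORCH_HARMONICS_DEBUG / _PROFILE."""
    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'

    if debug_mode:
        cxx_flags = ["-g", "-O0"]          # assumed host debug flags
        nvcc_flags = ["-g", "-G", "-O0"]   # assumed CUDA debug flags
    else:
        cxx_flags = ["-O3"]
        nvcc_flags = ["-O3"]

    if profile_mode:
        # assumption: keep line info so profilers can map kernels back to source
        nvcc_flags.append("-lineinfo")

    return {"cxx": cxx_flags, "nvcc": nvcc_flags}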
tests/test_attention.py

@@ -67,6 +67,8 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}

@parameterized_class(("device"), _devices)
class TestNeighborhoodAttentionS2(unittest.TestCase):
    """Test the neighborhood attention module (CPU/CUDA if available)."""

    def setUp(self):
        torch.manual_seed(333)
        if self.device.type == "cuda":
    ...
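Note: these tests are parameterized over devices via parameterized_class, but the _devices list itself sits outside the visible hunks. A minimal sketch of the pattern, assuming _devices is built from CPU plus CUDA when available (consistent with the "if torch.cuda.is_available():" context shown in tests/test_convolution.py); the test class below is a toy, not repository code.

import unittest
import torch
from parameterized import parameterized_class

# assumed construction of the device list consumed by @parameterized_class
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
    _devices.append((torch.device("cuda"),))

@parameterized_class(("device"), _devices)
class ExampleDeviceTest(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(333)

    def test_runs_on_device(self):
        x = torch.randn(4, 4, device=self.device)
        self.assertEqual(x.device.type, self.device.type)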
tests/test_cache.py

@@ -36,7 +36,6 @@ import torch

class TestCacheConsistency(unittest.TestCase):
    def test_consistency(self, verbose=False):
        if verbose:
            print("Testing that cache values does not get modified externally")
    ...
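Note: the body of test_consistency is collapsed. A hedged sketch of what such a consistency check can look like, tied to the copy-enabled lru_cache added in torch_harmonics/cache.py further below: mutate the first returned value and verify that a second call still sees the original. The cached function used here is hypothetical.

import unittest
from torch_harmonics.cache import lru_cache

class CacheConsistencySketch(unittest.TestCase):
    def test_consistency(self, verbose=False):
        if verbose:
            print("Testing that cache values do not get modified externally")

        @lru_cache(maxsize=4, copy=True)
        def make_list(x):
            return [x, x * 2, x * 3]

        first = make_list(2)
        first.append("corrupted")            # mutate the returned copy
        second = make_list(2)                # served from the cache
        self.assertEqual(second, [2, 4, 6])  # cached value must be unaffected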
tests/test_convolution.py

@@ -47,9 +47,7 @@ if torch.cuda.is_available():

def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
-   """
-   Discretely normalizes the convolution tensor.
-   """
+   """Discretely normalizes the convolution tensor."""

    kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
    correction_factor = nlon_out / nlon_in
    ...

@@ -98,10 +96,7 @@ def _precompute_convolution_tensor_dense(
    basis_norm_mode="none",
    merge_quadrature=False,
):
-   """
-   Helper routine to compute the convolution Tensor in a dense fashion
-   """
+   """Helper routine to compute the convolution Tensor in a dense fashion."""

    assert len(in_shape) == 2
    assert len(out_shape) == 2
    ...

@@ -168,6 +163,8 @@ def _precompute_convolution_tensor_dense(

@parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
    """Test the discrete-continuous convolution module (CPU/CUDA if available)."""

    def setUp(self):
        torch.manual_seed(333)
        if self.device.type == "cuda":
    ...
tests/test_distributed_convolution.py

@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd

class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    """Test the distributed discrete-continuous convolution module."""

    @classmethod
    def setUpClass(cls):
        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
    ...

@@ -118,6 +118,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
    ...

@@ -130,6 +131,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        return tensor_local

    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        # we need the shapes
        lat_shapes = convolution_dist.lat_out_shapes
        lon_shapes = convolution_dist.lon_out_shapes
    ...

@@ -157,6 +159,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        return tensor_gather

    def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
        # we need the shapes
        lat_shapes = convolution_dist.lat_in_shapes
        lon_shapes = convolution_dist.lon_in_shapes
    ...

@@ -204,7 +207,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape,
                                    basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol):

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        disco_args = dict(
    ...

@@ -238,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local conv
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = conv_local(inp_full)
    ...

@@ -254,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed conv
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
    ...

@@ -268,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
    ...

@@ -278,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
    ...
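Note: the forward and backward checks in these distributed tests share the same acceptance criterion, a mean relative Frobenius error between the gathered distributed result and the single-process result over the spatial dimensions, which is exactly what the torch.mean(torch.norm(...)) expressions above compute:

$\mathrm{err} = \operatorname{mean}_{b,c}\!\left(\frac{\lVert y^{\mathrm{full}}_{b,c} - y^{\mathrm{gathered}}_{b,c}\rVert_F}{\lVert y^{\mathrm{full}}_{b,c}\rVert_F}\right) \leq \mathrm{tol}$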
tests/test_distributed_resample.py

@@ -41,6 +41,7 @@ import torch_harmonics.distributed as thd

class TestDistributedResampling(unittest.TestCase):
    """Test the distributed resampling module (CPU/CUDA if available)."""

    @classmethod
    def setUpClass(cls):
    ...

@@ -118,6 +119,7 @@ class TestDistributedResampling(unittest.TestCase):
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
    ...

@@ -130,6 +132,7 @@ class TestDistributedResampling(unittest.TestCase):
        return tensor_local

    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        # we need the shapes
        lat_shapes = convolution_dist.lat_out_shapes
        lon_shapes = convolution_dist.lon_out_shapes
    ...

@@ -157,6 +160,7 @@ class TestDistributedResampling(unittest.TestCase):
        return tensor_gather

    def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
        # we need the shapes
        lat_shapes = resampling_dist.lat_in_shapes
        lon_shapes = resampling_dist.lon_in_shapes
    ...

@@ -196,7 +200,7 @@ class TestDistributedResampling(unittest.TestCase):
    def test_distributed_resampling(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
                                    grid_in, grid_out, mode, tol, verbose):

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        res_args = dict(
    ...

@@ -216,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local conv
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = res_local(inp_full)
    ...

@@ -232,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed conv
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
    ...

@@ -246,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
    ...

@@ -256,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
    ...
tests/test_distributed_sht.py

@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd

class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
    """Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""

    @classmethod
    def setUpClass(cls):
        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
    ...

@@ -163,6 +163,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        return tensor_gather

    def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
        # we need the shapes
        lat_shapes = transform_dist.lat_shapes
        lon_shapes = transform_dist.lon_shapes
    ...

@@ -214,6 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        ]
    )
    def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):

        B, C, H, W = batch_size, num_chan, nlat, nlon

        # set up handles
    ...

@@ -230,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        else:
            inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local transform
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = forward_transform_local(inp_full)
    ...

@@ -246,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
    ...

@@ -260,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
    ...

@@ -270,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
    ...

@@ -301,6 +295,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        ]
    )
    def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):

        B, C, H, W = batch_size, num_chan, nlat, nlon

        if vector:
    ...

@@ -340,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
    ...

@@ -354,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
    ...

@@ -364,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
    ...
tests/test_sht.py

@@ -42,7 +42,7 @@ if torch.cuda.is_available():

class TestLegendrePolynomials(unittest.TestCase):
    """Test the associated Legendre polynomials (CPU/CUDA if available)."""

    def setUp(self):
        self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
        self.pml = dict()
    ...

@@ -79,7 +79,7 @@ class TestLegendrePolynomials(unittest.TestCase):

@parameterized_class(("device"), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
    """Test the spherical harmonic transform (CPU/CUDA if available)."""

    def setUp(self):
        torch.manual_seed(333)
        if self.device.type == "cuda":
    ...
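Note: the lambda self.cml above computes the standard normalization constant of the orthonormalized associated Legendre polynomials; written out, it is

$C_\ell^m = \sqrt{\frac{2\ell+1}{4\pi}}\,\sqrt{\frac{(\ell-m)!}{(\ell+m)!}}$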
torch_harmonics/_disco_convolution.py

@@ -42,7 +42,7 @@ except ImportError as err:

# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int,
             nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
    """Creates a sparse tensor for spherical harmonic convolution operations."""

    nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
    nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
    ...

@@ -67,6 +67,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor,
                vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):

        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
    ...

@@ -81,6 +82,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):

        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
    ...

@@ -97,6 +99,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor,
                vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):

        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
    ...

@@ -111,6 +114,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):

        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
    ...

@@ -140,6 +144,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.
    """

    assert len(psi.shape) == 3
    assert len(x.shape) == 4

    psi = psi.to(x.device)
    ...

@@ -171,11 +176,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.
    """

    assert len(psi.shape) == 3
    assert len(x.shape) == 5

    psi = psi.to(x.device)
    ...
torch_harmonics/_neighborhood_attention.py

@@ -50,8 +50,6 @@ except ImportError as err:

def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor,
                                         col_idx: torch.Tensor, row_off: torch.Tensor,
                                         nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

    # prepare result tensor
    y = torch.zeros_like(qy)
    ...

@@ -170,7 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, ...
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):

    # shapes:
    # input
    # kx: B, C, Hi, Wi
    ...

@@ -252,6 +249,7 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, ...
                                            quad_weights: torch.Tensor,
                                            col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):

    # shapes:
    # input
    # kx: B, C, Hi, Wi
    ...

@@ -329,7 +327,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                nh: int, nlon_in: int, nlat_out: int, nlon_out: int):

        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh = nh
        ctx.nlon_in = nlon_in
    ...

@@ -443,7 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch. ...
                                     bq: Union[torch.Tensor, None],
                                     quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                     nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

    return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, nh, nlon_in, nlat_out, nlon_out)
    ...

@@ -451,6 +449,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch. ...
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
    ...

@@ -458,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):

        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh = nh
        ctx.max_psi_nnz = max_psi_nnz
    ...

@@ -584,7 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T ...
                                    bq: Union[torch.Tensor, None],
                                    quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                    max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

    return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out)
torch_harmonics/attention.py

@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):

    def extra_repr(self):
-       r"""
-       Pretty print module
-       """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
    ...

@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
            self.proj_bias = None

    def extra_repr(self):
-       r"""
-       Pretty print module
-       """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
    ...
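Note: the removed multi-line docstrings belonged to extra_repr, PyTorch's hook for enriching a module's printed representation; the string it returns is spliced into repr(module). A small generic example of that mechanism (a toy module, not torch-harmonics code):

import torch.nn as nn

class TinyModule(nn.Module):
    # toy module demonstrating the extra_repr hook used by AttentionS2 / NeighborhoodAttentionS2
    def __init__(self, in_shape=(91, 180), out_shape=(91, 180), in_channels=3, out_channels=3):
        super().__init__()
        self.in_shape, self.out_shape = in_shape, out_shape
        self.in_channels, self.out_channels = in_channels, out_channels

    def extra_repr(self):
        return f"in_shape={self.in_shape}, out_shape={self.out_shape}, in_channels={self.in_channels}, out_channels={self.out_channels}"

print(TinyModule())
# TinyModule(in_shape=(91, 180), out_shape=(91, 180), in_channels=3, out_channels=3)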
torch_harmonics/cache.py

@@ -35,6 +35,32 @@ from copy import deepcopy

# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
    """
    Least Recently Used (LRU) cache decorator with optional deep copying.

    This is a wrapper around functools.lru_cache that adds the ability to return
    deep copies of cached results to prevent unintended modifications to cached objects.

    Parameters
    -----------
    maxsize : int, optional
        Maximum number of items to cache, by default 20
    typed : bool, optional
        Whether to cache different types separately, by default False
    copy : bool, optional
        Whether to return deep copies of cached results, by default False

    Returns
    -------
    function
        Decorated function with LRU caching

    Example
    -------
    >>> @lru_cache(maxsize=10, copy=True)
    ... def expensive_function(x):
    ...     return [x, x*2, x*3]
    """

    def decorator(f):
        cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)

        def wrapper(*args, **kwargs):
    ...
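Note: the wrapper body is collapsed right after its definition. Based on the Stack Overflow recipe linked in the comment and the copy parameter documented above, a plausible completion looks like the sketch below; it is an assumption, not necessarily the exact repository code.

import functools
from copy import deepcopy

def lru_cache_sketch(maxsize=20, typed=False, copy=False):
    """Hypothetical completion of the copy-enabled LRU cache decorator."""
    def decorator(f):
        cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            result = cached_func(*args, **kwargs)
            # return a deep copy so callers cannot mutate the cached object
            return deepcopy(result) if copy else result

        # keep access to the cache-management helpers of functools.lru_cache
        wrapper.cache_clear = cached_func.cache_clear
        wrapper.cache_info = cached_func.cache_info
        return wrapper

    return decorator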