OpenDAS / torch-harmonics / Commits / c7afb546

Commit c7afb546 (unverified)
Authored Jul 21, 2025 by Thorsten Kurth, committed by GitHub on Jul 21, 2025

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR

Parents: b5c410c0, 644465ba
Pipeline #2854 canceled with stages
Changes: 44 | Pipelines: 2
Showing 20 changed files with 347 additions and 83 deletions (+347, -83)
Changelog.md                                   +1    -0
examples/baseline_models/segformer.py          +145  -3
examples/baseline_models/transformer.py        +65   -1
examples/baseline_models/unet.py               +68   -3
examples/depth/train.py                        +0    -3
examples/model_registry.py                     +1    -1
examples/segmentation/train.py                 +3    -4
examples/shallow_water_equations/train.py      +1    -4
setup.py                                       +3    -1
tests/test_attention.py                        +2    -0
tests/test_cache.py                            +0    -1
tests/test_convolution.py                      +4    -7
tests/test_distributed_convolution.py          +5    -10
tests/test_distributed_resample.py             +5    -9
tests/test_distributed_sht.py                  +4    -15
tests/test_sht.py                              +2    -2
torch_harmonics/_disco_convolution.py          +6    -6
torch_harmonics/_neighborhood_attention.py     +6    -7
torch_harmonics/attention.py                   +0    -6
torch_harmonics/cache.py                       +26   -0
Changelog.md

@@ -14,6 +14,7 @@
 * Reorganized examples folder, including new examples based on the 2d3ds dataset
 * Added spherical loss functions to examples
 * Added plotting module
+* Updated docstrings

 ### v0.7.6
examples/baseline_models/segformer.py

@@ -41,6 +41,24 @@ from functools import partial
 class OverlapPatchMerging(nn.Module):
+    """
+    OverlapPatchMerging layer for merging patches.
+
+    Parameters
+    -----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    kernel_shape : tuple
+        Kernel shape for convolution
+    bias : bool, optional
+        Whether to use bias, by default False
+    """
+
     def __init__(
         self,
         in_shape=(721, 1440),

@@ -88,6 +106,30 @@ class OverlapPatchMerging(nn.Module):
 class MixFFN(nn.Module):
+    """
+    MixFFN module combining MLP and depthwise convolution.
+
+    Parameters
+    -----------
+    shape : tuple
+        Input shape (height, width)
+    inout_channels : int
+        Number of input/output channels
+    hidden_channels : int
+        Number of hidden channels in MLP
+    mlp_bias : bool, optional
+        Whether to use bias in MLP layers, by default True
+    kernel_shape : tuple, optional
+        Kernel shape for depthwise convolution, by default (3, 3)
+    conv_bias : bool, optional
+        Whether to use bias in convolution, by default False
+    activation : callable, optional
+        Activation function, by default nn.GELU
+    use_mlp : bool, optional
+        Whether to use MLP instead of linear layers, by default False
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    """
+
     def __init__(
         self,
         shape,

@@ -142,7 +184,7 @@ class MixFFN(nn.Module):
         x = x.permute(0, 3, 1, 2)
         # NOTE: we add another activation here
-        # because in the paper they only use depthwise conv,
+        # because in the paper the authors only use depthwise conv,
         # but without this activation it would just be a fused MM
         # with the disco conv
         x = self.mlp_in(x)

@@ -162,6 +204,17 @@ class GlobalAttention(nn.Module):
     Input shape: (B, C, H, W)
     Output shape: (B, C, H, W) with residual skip.
+
+    Parameters
+    -----------
+    chans : int
+        Number of channels
+    num_heads : int, optional
+        Number of attention heads, by default 8
+    dropout : float, optional
+        Dropout rate, by default 0.0
+    bias : bool, optional
+        Whether to use bias, by default True
     """

     def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):

@@ -169,6 +222,7 @@ class GlobalAttention(nn.Module):
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
         # x: B, C, H, W
         B, H, W, C = x.shape
         # flatten spatial dims

@@ -181,6 +235,30 @@ class GlobalAttention(nn.Module):
 class AttentionWrapper(nn.Module):
+    """
+    Wrapper for different attention mechanisms.
+
+    Parameters
+    -----------
+    channels : int
+        Number of channels
+    shape : tuple
+        Input shape (height, width)
+    heads : int
+        Number of attention heads
+    pre_norm : bool, optional
+        Whether to apply normalization before attention, by default False
+    attention_drop_rate : float, optional
+        Attention dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
+    bias : bool, optional
+        Whether to use bias, by default True
+    """
+
     def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
         super().__init__()

@@ -203,11 +281,13 @@ class AttentionWrapper(nn.Module):
         self.apply(self._init_weights)

     def _init_weights(self, m):
         if isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         residual = x
         x = x.permute(0, 2, 3, 1)
         if self.norm is not None:

@@ -219,6 +299,41 @@ class AttentionWrapper(nn.Module):
 class TransformerBlock(nn.Module):
+    """
+    Transformer block with attention and MLP.
+
+    Parameters
+    ----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    mlp_hidden_channels : int
+        Number of hidden channels in MLP
+    nrep : int, optional
+        Number of repetitions of attention and MLP blocks, by default 1
+    heads : int, optional
+        Number of attention heads, by default 1
+    kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (3, 3)
+    activation : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    att_drop_rate : float, optional
+        Attention dropout rate, by default 0.0
+    drop_path_rates : float or list, optional
+        Drop path rates for each block, by default 0.0
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    attn_kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
+    bias : bool, optional
+        Whether to use bias, by default True
+    """
+
     def __init__(
         self,
         in_shape,

@@ -341,6 +456,33 @@ class TransformerBlock(nn.Module):
 class Upsampling(nn.Module):
+    """
+    Upsampling block for the Segformer model.
+
+    Parameters
+    ----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    hidden_channels : int
+        Number of hidden channels in MLP
+    mlp_bias : bool, optional
+        Whether to use bias in MLP, by default True
+    kernel_shape : tuple, optional
+        Kernel shape for convolution, by default (3, 3)
+    conv_bias : bool, optional
+        Whether to use bias in convolution, by default False
+    activation : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    use_mlp : bool, optional
+        Whether to use MLP, by default False
+    """
+
     def __init__(
         self,
         in_shape,

@@ -382,7 +524,7 @@ class Segformer(nn.Module):
     Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks

     Parameters
-    -----------
+    ----------
     img_shape : tuple, optional
         Shape of the input channels, by default (128, 256)
     kernel_shape: tuple, int

@@ -414,7 +556,7 @@ class Segformer(nn.Module):
         Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"

     Example
-    -----------
+    ----------
     >>> model = Segformer(
     ...     img_size=(128, 256),
     ...     in_chans=3,
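The GlobalAttention docstring above specifies (B, C, H, W) in and out with a residual skip, and the shown context wires it to nn.MultiheadAttention with batch_first=True. A minimal, self-contained sketch of that pattern (hypothetical class name, not the repository code) looks like this:

# Sketch only: flatten the spatial grid into a sequence, attend, restore the grid.
import torch
import torch.nn as nn

class TinyGlobalAttention(nn.Module):
    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        # x: (B, C, H, W)
        B, C, H, W = x.shape
        seq = x.permute(0, 2, 3, 1).reshape(B, H * W, C)    # (B, H*W, C) sequence
        out, _ = self.attn(seq, seq, seq)                    # self-attention over all grid points
        out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)    # back to (B, C, H, W)
        return x + out                                       # residual skip

x = torch.randn(2, 64, 16, 32)
print(TinyGlobalAttention(64)(x).shape)  # torch.Size([2, 64, 16, 32])

The permute/reshape round trip is the only glue needed because batch_first=True expects (batch, sequence, embedding).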
examples/baseline_models/transformer.py

@@ -57,11 +57,34 @@ class Encoder(nn.Module):
         self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)

     def forward(self, x):
         x = self.conv(x)
         return x

 class Decoder(nn.Module):
+    """
+    Decoder module for upsampling and feature processing.
+
+    Parameters
+    -----------
+    in_shape : tuple, optional
+        Input shape (height, width), by default (480, 960)
+    out_shape : tuple, optional
+        Output shape (height, width), by default (721, 1440)
+    in_chans : int, optional
+        Number of input channels, by default 2
+    out_chans : int, optional
+        Number of output channels, by default 2
+    kernel_shape : tuple, optional
+        Kernel shape for convolution, by default (3, 3)
+    groups : int, optional
+        Number of groups for convolution, by default 1
+    bias : bool, optional
+        Whether to use bias, by default False
+    upsampling_method : str, optional
+        Upsampling method ("conv", "pixel_shuffle"), by default "conv"
+    """
+
     def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
         super().__init__()

         self.out_shape = out_shape

@@ -87,6 +110,7 @@ class Decoder(nn.Module):
             raise ValueError(f"Unknown upsampling method {upsampling_method}")

     def forward(self, x):
         x = self.upsample(x)
         return x

@@ -97,6 +121,17 @@ class GlobalAttention(nn.Module):
     Input shape: (B, C, H, W)
     Output shape: (B, C, H, W) with residual skip.
+
+    Parameters
+    -----------
+    chans : int
+        Number of channels
+    num_heads : int, optional
+        Number of attention heads, by default 8
+    dropout : float, optional
+        Dropout rate, by default 0.0
+    bias : bool, optional
+        Whether to use bias, by default True
     """

     def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):

@@ -104,6 +139,7 @@ class GlobalAttention(nn.Module):
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
         # x: B, C, H, W
         B, H, W, C = x.shape
         # flatten spatial dims

@@ -118,8 +154,36 @@ class GlobalAttention(nn.Module):
 class AttentionBlock(nn.Module):
     """
     Neighborhood attention block based on Natten.
-    """
+
+    Parameters
+    -----------
+    in_shape : tuple, optional
+        Input shape (height, width), by default (480, 960)
+    out_shape : tuple, optional
+        Output shape (height, width), by default (480, 960)
+    chans : int, optional
+        Number of channels, by default 2
+    num_heads : int, optional
+        Number of attention heads, by default 1
+    mlp_ratio : float, optional
+        Ratio of MLP hidden dim to input dim, by default 2.0
+    drop_rate : float, optional
+        Dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    act_layer : callable, optional
+        Activation function, by default nn.GELU
+    norm_layer : str, optional
+        Normalization layer type, by default "none"
+    use_mlp : bool, optional
+        Whether to use MLP, by default True
+    bias : bool, optional
+        Whether to use bias, by default True
+    attention_mode : str, optional
+        Attention mode ("neighborhood", "global"), by default "neighborhood"
+    attn_kernel_shape : tuple, optional
+        Kernel shape for neighborhood attention, by default (7, 7)
+    """
+
     def __init__(
         self,
         in_shape=(480, 960),
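The new Decoder docstring names two upsampling_method options, "conv" and "pixel_shuffle", but the bodies of those branches are not shown in this hunk. The sketch below (hypothetical TinyDecoder; the "conv" path is assumed to interpolate to the target grid and the PixelShuffle path assumes an integer upscale factor) illustrates how the two documented options typically differ:

# Sketch under stated assumptions, not the repository implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoder(nn.Module):
    def __init__(self, out_shape=(721, 1440), in_chans=2, out_chans=2, upsampling_method="conv", upscale=2):
        super().__init__()
        self.out_shape = out_shape
        if upsampling_method == "conv":
            # resize to the target grid, then mix channels with a 1x1 conv
            self.head = nn.Conv2d(in_chans, out_chans, kernel_size=1)
            self.upsample = lambda x: self.head(F.interpolate(x, size=self.out_shape, mode="bilinear"))
        elif upsampling_method == "pixel_shuffle":
            # produce upscale**2 sub-pixel channels, then rearrange them spatially
            self.head = nn.Conv2d(in_chans, out_chans * upscale**2, kernel_size=1)
            self.shuffle = nn.PixelShuffle(upscale)
            self.upsample = lambda x: self.shuffle(self.head(x))
        else:
            raise ValueError(f"Unknown upsampling method {upsampling_method}")

    def forward(self, x):
        return self.upsample(x)

x = torch.randn(1, 2, 480, 960)
print(TinyDecoder(upsampling_method="conv")(x).shape)           # torch.Size([1, 2, 721, 1440])
print(TinyDecoder(upsampling_method="pixel_shuffle")(x).shape)  # torch.Size([1, 2, 960, 1920])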
examples/baseline_models/unet.py

@@ -43,6 +43,37 @@ from functools import partial
 class DownsamplingBlock(nn.Module):
+    """
+    Downsampling block for the UNet model.
+
+    Parameters
+    ----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    nrep : int, optional
+        Number of repetitions of conv blocks, by default 1
+    kernel_shape : tuple, optional
+        Kernel shape for convolutions, by default (3, 3)
+    activation : callable, optional
+        Activation function, by default nn.ReLU
+    transform_skip : bool, optional
+        Whether to transform skip connections, by default False
+    drop_conv_rate : float, optional
+        Dropout rate for convolutions, by default 0.
+    drop_path_rate : float, optional
+        Drop path rate, by default 0.
+    drop_dense_rate : float, optional
+        Dropout rate for dense layers, by default 0.
+    downsampling_mode : str, optional
+        Downsampling mode ("bilinear", "conv"), by default "bilinear"
+    """
+
     def __init__(
         self,
         in_shape,

@@ -146,6 +177,7 @@ class DownsamplingBlock(nn.Module):
             nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # skip connection
         residual = x
         if hasattr(self, "transform_skip"):

@@ -166,6 +198,36 @@ class DownsamplingBlock(nn.Module):
 class UpsamplingBlock(nn.Module):
+    """
+    Upsampling block for UNet architecture.
+
+    Parameters
+    -----------
+    in_shape : tuple
+        Input shape (height, width)
+    out_shape : tuple
+        Output shape (height, width)
+    in_channels : int
+        Number of input channels
+    out_channels : int
+        Number of output channels
+    nrep : int, optional
+        Number of repetitions of conv blocks, by default 1
+    kernel_shape : tuple, optional
+        Kernel shape for convolutions, by default (3, 3)
+    activation : callable, optional
+        Activation function, by default nn.ReLU
+    transform_skip : bool, optional
+        Whether to transform skip connections, by default False
+    drop_conv_rate : float, optional
+        Dropout rate for convolutions, by default 0.
+    drop_path_rate : float, optional
+        Drop path rate, by default 0.
+    drop_dense_rate : float, optional
+        Dropout rate for dense layers, by default 0.
+    upsampling_mode : str, optional
+        Upsampling mode ("bilinear", "conv"), by default "bilinear"
+    """
+
     def __init__(
         self,
         in_shape,

@@ -280,6 +342,7 @@ class UpsamplingBlock(nn.Module):
             nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # skip connection
         residual = x
         if hasattr(self, "transform_skip"):

@@ -304,6 +367,7 @@ class UNet(nn.Module):
     img_shape : tuple, optional
         Shape of the input channels, by default (128, 256)
     kernel_shape: tuple, int
+        Kernel shape for convolutions
     scale_factor: int, optional
         Scale factor to use, by default 2
     in_chans : int, optional

@@ -336,11 +400,12 @@ class UNet(nn.Module):
     ...     scale_factor=4,
     ...     in_chans=2,
     ...     num_classes=2,
-    ...     embed_dims=[64, 128, 256, 512],)
+    ...     embed_dims=[16, 32, 64, 128],
+    ...     depths=[2, 2, 2, 2],
+    ...     use_mlp=True,)
     >>> model(torch.randn(1, 2, 128, 256)).shape
     torch.Size([1, 2, 128, 256])
     """

     def __init__(
         self,
         img_shape=(128, 256),
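Both UNet block forward passes shown above keep the input as a residual and branch on hasattr(self, "transform_skip"), projecting the skip only when that submodule was registered. A minimal sketch of that idiom (hypothetical TinyResidualBlock, not the repository code):

import torch
import torch.nn as nn

class TinyResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, transform_skip=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if transform_skip or in_channels != out_channels:
            # hypothetical 1x1 projection so the residual matches the output width
            self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
            residual = self.transform_skip(residual)
        return self.conv(x) + residual

print(TinyResidualBlock(16, 32)(torch.randn(1, 16, 64, 128)).shape)  # torch.Size([1, 32, 64, 128])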
examples/depth/train.py

@@ -68,9 +68,6 @@ def count_parameters(model):
 # convenience function for logging weights and gradients
 def log_weights_and_grads(exp_dir, model, iters=1):
-    """
-    Helper routine intended for debugging purposes
-    """
     log_path = os.path.join(exp_dir, "weights_and_grads")
     if not os.path.isdir(log_path):
         os.makedirs(log_path, exist_ok=True)
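The body of log_weights_and_grads beyond the directory setup is not shown in this diff; a hedged sketch of what such a debugging helper could do (hypothetical name log_weights_and_grads_sketch, filename pattern borrowed from the shallow-water example below):

import os
import torch

def log_weights_and_grads_sketch(exp_dir, model, iters=1):
    # dump each parameter and its gradient (if any) for offline inspection
    log_path = os.path.join(exp_dir, "weights_and_grads")
    os.makedirs(log_path, exist_ok=True)
    state = {
        name: (p.detach().cpu(), None if p.grad is None else p.grad.detach().cpu())
        for name, p in model.named_parameters()
    }
    torch.save(state, os.path.join(log_path, f"weights_and_grads_step{iters:03d}.tar"))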
examples/model_registry.py
examples/segmentation/train.py

@@ -68,14 +68,13 @@ import wandb
 # helper routine for counting number of paramerters in model
 def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 # convenience function for logging weights and gradients
 def log_weights_and_grads(exp_dir, model, iters=1):
-    """
-    Helper routine intended for debugging purposes
-    """
     log_path = os.path.join(exp_dir, "weights_and_grads")
     if not os.path.isdir(log_path):
         os.makedirs(log_path, exist_ok=True)
examples/shallow_water_equations/train.py

@@ -68,9 +68,6 @@ def count_parameters(model):
 # convenience function for logging weights and gradients
 def log_weights_and_grads(model, iters=1):
-    """
-    Helper routine intended for debugging purposes
-    """
     root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
     weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
setup.py

@@ -55,6 +55,7 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
 def get_compile_args(module_name):
     """If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
     debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
     profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'

@@ -77,6 +78,7 @@ def get_compile_args(module_name):
 }

 def get_ext_modules():
+    """Get list of extension modules to compile."""
     ext_modules = []
     cmdclass = {}
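The docstring on get_compile_args states that building with TORCH_HARMONICS_DEBUG=1 switches to debugging flags, and the context shows TORCH_HARMONICS_PROFILE=1 as a parallel switch. A hedged sketch of driving such a build from Python (the concrete compile flags chosen by setup.py are not shown in this diff, and the build command is an assumption):

import os
import subprocess
import sys

# assumption: an in-place extension build picks up the environment variable
env = dict(os.environ, TORCH_HARMONICS_DEBUG="1")   # enable the debug flag path
# env["TORCH_HARMONICS_PROFILE"] = "1"              # or the profiling path instead
subprocess.run([sys.executable, "setup.py", "build_ext", "--inplace"], env=env, check=True)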
tests/test_attention.py

@@ -67,6 +67,8 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
 @parameterized_class(("device"), _devices)
 class TestNeighborhoodAttentionS2(unittest.TestCase):
+    """Test the neighborhood attention module (CPU/CUDA if available)."""
+
     def setUp(self):
         torch.manual_seed(333)
         if self.device.type == "cuda":
tests/test_cache.py

@@ -36,7 +36,6 @@ import torch
 class TestCacheConsistency(unittest.TestCase):
     def test_consistency(self, verbose=False):

         if verbose:
             print("Testing that cache values does not get modified externally")
tests/test_convolution.py

@@ -47,9 +47,7 @@ if torch.cuda.is_available():
 def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
-    """
-    Discretely normalizes the convolution tensor.
-    """
+    """Discretely normalizes the convolution tensor."""

     kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
     correction_factor = nlon_out / nlon_in

@@ -98,10 +96,7 @@ def _precompute_convolution_tensor_dense(
     basis_norm_mode="none",
     merge_quadrature=False,
 ):
-    """
-    Helper routine to compute the convolution Tensor in a dense fashion
-    """
+    """Helper routine to compute the convolution Tensor in a dense fashion."""

     assert len(in_shape) == 2
     assert len(out_shape) == 2

@@ -168,6 +163,8 @@ def _precompute_convolution_tensor_dense(
 @parameterized_class(("device"), _devices)
 class TestDiscreteContinuousConvolution(unittest.TestCase):
+    """Test the discrete-continuous convolution module (CPU/CUDA if available)."""
+
     def setUp(self):
         torch.manual_seed(333)
         if self.device.type == "cuda":
tests/test_distributed_convolution.py

@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
 class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
+    """Test the distributed discrete-continuous convolution module."""
+
     @classmethod
     def setUpClass(cls):

         # set up distributed
         cls.world_rank = int(os.getenv("WORLD_RANK", 0))
         cls.grid_size_h = int(os.getenv("GRID_H", 1))

@@ -118,6 +118,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         dist.destroy_process_group(None)

     def _split_helper(self, tensor):
         with torch.no_grad():
             # split in W
             tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)

@@ -130,6 +131,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         return tensor_local

     def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_out_shapes
         lon_shapes = convolution_dist.lon_out_shapes

@@ -157,6 +159,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         return tensor_gather

     def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_in_shapes
         lon_shapes = convolution_dist.lon_in_shapes

@@ -238,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         # create tensors
         inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

-        #############################################################
         # local conv
-        #############################################################
         # FWD pass
         inp_full.requires_grad = True
         out_full = conv_local(inp_full)

@@ -254,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         out_full.backward(ograd_full)
         igrad_full = inp_full.grad.clone()

-        #############################################################
         # distributed conv
-        #############################################################
         # FWD pass
         inp_local = self._split_helper(inp_full)
         inp_local.requires_grad = True

@@ -268,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         out_local.backward(ograd_local)
         igrad_local = inp_local.grad.clone()

-        #############################################################
         # evaluate FWD pass
-        #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
             err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))

@@ -278,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
             print(f"final relative error of output: {err.item()}")
         self.assertTrue(err.item() <= tol)

-        #############################################################
         # evaluate BWD pass
-        #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
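The pass/fail criterion asserted against tol in these distributed tests (and in the resampling and SHT tests that follow) is a mean relative Frobenius error over the last two, spatial, dimensions. A standalone restatement of that metric, lifted directly from the shown err expression:

import torch

def relative_frobenius_error(reference, candidate):
    # ||ref - cand||_F / ||ref||_F per (batch, channel) slice, then averaged
    num = torch.norm(reference - candidate, p="fro", dim=(-1, -2))
    den = torch.norm(reference, p="fro", dim=(-1, -2))
    return torch.mean(num / den)

a = torch.randn(2, 3, 8, 16)
print(relative_frobenius_error(a, a + 1e-6 * torch.randn_like(a)).item())  # on the order of 1e-6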
tests/test_distributed_resample.py

@@ -41,6 +41,7 @@ import torch_harmonics.distributed as thd
 class TestDistributedResampling(unittest.TestCase):
+    """Test the distributed resampling module (CPU/CUDA if available)."""
     @classmethod
     def setUpClass(cls):

@@ -118,6 +119,7 @@ class TestDistributedResampling(unittest.TestCase):
         dist.destroy_process_group(None)

     def _split_helper(self, tensor):
         with torch.no_grad():
             # split in W
             tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)

@@ -130,6 +132,7 @@ class TestDistributedResampling(unittest.TestCase):
         return tensor_local

     def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_out_shapes
         lon_shapes = convolution_dist.lon_out_shapes

@@ -157,6 +160,7 @@ class TestDistributedResampling(unittest.TestCase):
         return tensor_gather

     def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
         # we need the shapes
         lat_shapes = resampling_dist.lat_in_shapes
         lon_shapes = resampling_dist.lon_in_shapes

@@ -216,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
         # create tensors
         inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

-        #############################################################
         # local conv
-        #############################################################
         # FWD pass
         inp_full.requires_grad = True
         out_full = res_local(inp_full)

@@ -232,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
         out_full.backward(ograd_full)
         igrad_full = inp_full.grad.clone()

-        #############################################################
         # distributed conv
-        #############################################################
         # FWD pass
         inp_local = self._split_helper(inp_full)
         inp_local.requires_grad = True

@@ -246,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
         out_local.backward(ograd_local)
         igrad_local = inp_local.grad.clone()

-        #############################################################
         # evaluate FWD pass
-        #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
             err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))

@@ -256,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
             print(f"final relative error of output: {err.item()}")
         self.assertTrue(err.item() <= tol)

-        #############################################################
         # evaluate BWD pass
-        #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
tests/test_distributed_sht.py

@@ -41,10 +41,10 @@ import torch_harmonics.distributed as thd
 class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
+    """Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""
+
     @classmethod
     def setUpClass(cls):

         # set up distributed
         cls.world_rank = int(os.getenv("WORLD_RANK", 0))
         cls.grid_size_h = int(os.getenv("GRID_H", 1))

@@ -163,6 +163,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         return tensor_gather

     def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
         # we need the shapes
         lat_shapes = transform_dist.lat_shapes
         lon_shapes = transform_dist.lon_shapes

@@ -214,6 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         ]
     )
     def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
         B, C, H, W = batch_size, num_chan, nlat, nlon

         # set up handles

@@ -230,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         else:
             inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

-        #############################################################
         # local transform
-        #############################################################
         # FWD pass
         inp_full.requires_grad = True
         out_full = forward_transform_local(inp_full)

@@ -246,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         out_full.backward(ograd_full)
         igrad_full = inp_full.grad.clone()

-        #############################################################
         # distributed transform
-        #############################################################
         # FWD pass
         inp_local = self._split_helper(inp_full)
         inp_local.requires_grad = True

@@ -260,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         out_local.backward(ograd_local)
         igrad_local = inp_local.grad.clone()

-        #############################################################
         # evaluate FWD pass
-        #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
             err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))

@@ -270,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
             print(f"final relative error of output: {err.item()}")
         self.assertTrue(err.item() <= tol)

-        #############################################################
         # evaluate BWD pass
-        #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
             err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))

@@ -301,6 +295,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         ]
     )
     def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
         B, C, H, W = batch_size, num_chan, nlat, nlon
         if vector:

@@ -340,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         out_full.backward(ograd_full)
         igrad_full = inp_full.grad.clone()

-        #############################################################
         # distributed transform
-        #############################################################
         # FWD pass
         inp_local = self._split_helper(inp_full)
         inp_local.requires_grad = True

@@ -354,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
         out_local.backward(ograd_local)
         igrad_local = inp_local.grad.clone()

-        #############################################################
         # evaluate FWD pass
-        #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
             err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))

@@ -364,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
             print(f"final relative error of output: {err.item()}")
         self.assertTrue(err.item() <= tol)

-        #############################################################
         # evaluate BWD pass
-        #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
             err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
tests/test_sht.py

@@ -42,7 +42,7 @@ if torch.cuda.is_available():
 class TestLegendrePolynomials(unittest.TestCase):
+    """Test the associated Legendre polynomials (CPU/CUDA if available)."""
+
     def setUp(self):
         self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
         self.pml = dict()

@@ -79,7 +79,7 @@ class TestLegendrePolynomials(unittest.TestCase):
 @parameterized_class(("device"), _devices)
 class TestSphericalHarmonicTransform(unittest.TestCase):
+    """Test the spherical harmonic transform (CPU/CUDA if available)."""
+
     def setUp(self):
         torch.manual_seed(333)
         if self.device.type == "cuda":
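For reference, the self.cml lambda in TestLegendrePolynomials.setUp computes the standard orthonormalization constant of the associated Legendre functions,

N_\ell^m = \sqrt{\frac{2\ell + 1}{4\pi}\,\frac{(\ell - m)!}{(\ell + m)!}},

i.e. the factor that turns P_\ell^m(\cos\theta)\,e^{im\varphi} into an orthonormal spherical harmonic Y_\ell^m on the unit sphere, which is what the reference values in this test are checked against.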
torch_harmonics/_disco_convolution.py

@@ -42,7 +42,7 @@ except ImportError as err:
 # some helper functions
 def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
+    """Creates a sparse tensor for spherical harmonic convolution operations."""

     nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
     nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out

@@ -67,6 +67,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):

         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]

@@ -81,6 +82,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):

         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
         grad_output = grad_output.to(torch.float32).contiguous()

@@ -97,6 +99,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):

         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]

@@ -111,6 +114,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):

         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
         grad_output = grad_output.to(torch.float32).contiguous()

@@ -140,6 +144,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
     shifting of the input tensor, which can potentially be costly. For an efficient implementation
     on GPU, make sure to use the custom kernel written in CUDA.
     """

     assert len(psi.shape) == 3
     assert len(x.shape) == 4
     psi = psi.to(x.device)

@@ -171,11 +176,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
 def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
-    """
-    Reference implementation of the custom contraction as described in [1]. This requires repeated
-    shifting of the input tensor, which can potentially be costly. For an efficient implementation
-    on GPU, make sure to use the custom kernel written in CUDA.
-    """
     assert len(psi.shape) == 3
     assert len(x.shape) == 5
     psi = psi.to(x.device)
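The new one-line docstring says that _get_psi creates a sparse tensor from the precomputed psi_idx/psi_vals pair. The actual index layout is repository-specific and not visible in this hunk; a generic sketch of assembling a sparse COO tensor from index and value tensors (hypothetical helper, assumed shapes) looks like:

import torch

def sparse_from_idx_vals(idx, vals, shape):
    # idx: (ndim, nnz) integer tensor of coordinates, vals: (nnz,) filter values
    return torch.sparse_coo_tensor(idx, vals, size=shape).coalesce()

idx = torch.tensor([[0, 1, 2], [3, 0, 7]])
vals = torch.tensor([0.5, 0.25, 0.25])
psi = sparse_from_idx_vals(idx, vals, (4, 8))
print(psi.to_dense().shape)  # torch.Size([4, 8])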
torch_harmonics/_neighborhood_attention.py

@@ -50,8 +50,6 @@ except ImportError as err:
 def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                          quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                          nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

     # prepare result tensor
     y = torch.zeros_like(qy)

@@ -170,7 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
 def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                             quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                             nlon_in: int, nlat_out: int, nlon_out: int):

     # shapes:
     # input
     # kx: B, C, Hi, Wi

@@ -252,6 +249,7 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
                                             quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                             nlon_in: int, nlat_out: int, nlon_out: int):

     # shapes:
     # input
     # kx: B, C, Hi, Wi

@@ -451,6 +449,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
 class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):

     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
torch_harmonics/attention.py

@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

     def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:

@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
             self.proj_bias = None

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

     def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
torch_harmonics/cache.py

@@ -35,6 +35,32 @@ from copy import deepcopy
 # copying LRU cache decorator a la:
 # https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
 def lru_cache(maxsize=20, typed=False, copy=False):
+    """
+    Least Recently Used (LRU) cache decorator with optional deep copying.
+
+    This is a wrapper around functools.lru_cache that adds the ability to return
+    deep copies of cached results to prevent unintended modifications to cached objects.
+
+    Parameters
+    -----------
+    maxsize : int, optional
+        Maximum number of items to cache, by default 20
+    typed : bool, optional
+        Whether to cache different types separately, by default False
+    copy : bool, optional
+        Whether to return deep copies of cached results, by default False
+
+    Returns
+    -------
+    function
+        Decorated function with LRU caching
+
+    Example
+    -------
+    >>> @lru_cache(maxsize=10, copy=True)
+    ... def expensive_function(x):
+    ...     return [x, x*2, x*3]
+    """
+
     def decorator(f):
         cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)

         def wrapper(*args, **kwargs):
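The hunk above truncates the wrapper body, so the following is a hedged sketch of the complete decorator pattern the new docstring describes (names and details assumed where not shown): cache with functools.lru_cache and, when copy=True, return deep copies so callers cannot mutate the stored objects.

import functools
from copy import deepcopy

def lru_cache_sketch(maxsize=20, typed=False, copy=False):
    def decorator(f):
        cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
        if not copy:
            return cached_func

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # hand back a fresh copy so in-place edits don't poison the cache
            return deepcopy(cached_func(*args, **kwargs))

        return wrapper

    return decorator

@lru_cache_sketch(maxsize=10, copy=True)
def expensive_function(x):
    return [x, x * 2, x * 3]

first = expensive_function(1)
first.append("mutated")
print(expensive_function(1))  # [1, 2, 3] -- the cached value is unaffected

This mutation-safety property is exactly what tests/test_cache.py checks ("cache values does not get modified externally").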