Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
313b1b73
Commit
313b1b73
authored
Jun 26, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Corrected docstrings in _layers.py
parent
e4879676
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
230 additions
and
222 deletions
+230
-222
torch_harmonics/examples/models/_layers.py
torch_harmonics/examples/models/_layers.py
+230
-222
No files found.
torch_harmonics/examples/models/_layers.py
View file @
313b1b73
...
@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
...
@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
"""
"""
Internal function to fill tensor with truncated normal distribution values.
Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters
Parameters
-----------
-----------
tensor : torch.Tensor
tensor : torch.Tensor
Tensor to
fill with values
Tensor to
initialize
mean : float
mean : float
Mean of the normal distribution
Mean of the normal distribution
std : float
std : float
...
@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
...
@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
The fill
ed tensor
Initializ
ed tensor
"""
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def
norm_cdf
(
x
):
def
norm_cdf
(
x
):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function
# Computes standard normal cumulative distribution function
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
...
@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
...
@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
drop_path
(
x
:
torch
.
Tensor
,
drop_prob
:
float
=
0.0
,
training
:
bool
=
False
)
->
torch
.
Tensor
:
def
drop_path
(
x
:
torch
.
Tensor
,
drop_prob
:
float
=
0.0
,
training
:
bool
=
False
)
->
torch
.
Tensor
:
"""
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
"""
"""
if
drop_prob
==
0.0
or
not
training
:
if
drop_prob
==
0.0
or
not
training
:
return
x
return
x
...
@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
...
@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class
DropPath
(
nn
.
Module
):
class
DropPath
(
nn
.
Module
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters
-----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
def
__init__
(
self
,
drop_prob
=
None
):
def
__init__
(
self
,
drop_prob
=
None
):
"""
Initialize DropPath module.
Parameters
-----------
drop_prob : float, optional
Dropout probability, by default None
"""
super
(
DropPath
,
self
).
__init__
()
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Forward pass with drop path.
Forward pass with drop path
regularization
.
Parameters
Parameters
-----------
-----------
...
@@ -177,7 +180,7 @@ class DropPath(nn.Module):
...
@@ -177,7 +180,7 @@ class DropPath(nn.Module):
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Output tensor with potential
drop path applied
Output tensor with potential
path dropping
"""
"""
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
...
@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
...
@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
"""
"""
Patch embedding layer for vision transformers.
Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters
Parameters
-----------
-----------
img_size : tuple, optional
img_size : tuple, optional
...
@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
...
@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
Parameters
Parameters
-----------
-----------
x : torch.Tensor
x : torch.Tensor
Input tensor
with
shape (batch, channels, height, width)
Input tensor
of
shape (batch
_size
, channels, height, width)
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Embedded patches with
shape (batch, embed_dim, num_patches)
Patch embeddings of
shape (batch
_size
, embed_dim, num_patches)
"""
"""
# gather input
# gather input
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
...
@@ -235,6 +241,9 @@ class MLP(nn.Module):
...
@@ -235,6 +241,9 @@ class MLP(nn.Module):
"""
"""
Multi-layer perceptron with optional checkpointing.
Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters
Parameters
-----------
-----------
in_features : int
in_features : int
...
@@ -252,7 +261,7 @@ class MLP(nn.Module):
...
@@ -252,7 +261,7 @@ class MLP(nn.Module):
checkpointing : bool, optional
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
Whether to use gradient checkpointing, by default False
gain : float, optional
gain : float, optional
Gain factor for
outpu
t initialization, by default 1.0
Gain factor for
weigh
t initialization, by default 1.0
"""
"""
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
ReLU
,
output_bias
=
False
,
drop_rate
=
0.0
,
checkpointing
=
False
,
gain
=
1.0
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
ReLU
,
output_bias
=
False
,
drop_rate
=
0.0
,
checkpointing
=
False
,
gain
=
1.0
):
...
@@ -325,24 +334,24 @@ class MLP(nn.Module):
...
@@ -325,24 +334,24 @@ class MLP(nn.Module):
class
RealFFT2
(
nn
.
Module
):
class
RealFFT2
(
nn
.
Module
):
"""
"""
Helper routine to wrap FFT similarly to the SHT
Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
"""
Initialize RealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super
(
RealFFT2
,
self
).
__init__
()
super
(
RealFFT2
,
self
).
__init__
()
self
.
nlat
=
nlat
self
.
nlat
=
nlat
...
@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
...
@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Forward pass
of R
ealFFT2.
Forward pass
: compute r
eal
FFT2
D
.
Parameters
Parameters
-----------
-----------
x : torch.Tensor
x : torch.Tensor
Input tensor
with shape (batch, channels, nlat, nlon)
Input tensor
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax)
FFT coefficients
"""
"""
y
=
torch
.
fft
.
rfft2
(
x
,
dim
=
(
-
2
,
-
1
),
norm
=
"ortho"
)
y
=
torch
.
fft
.
rfft2
(
x
,
dim
=
(
-
2
,
-
1
),
norm
=
"ortho"
)
y
=
torch
.
cat
((
y
[...,
:
math
.
ceil
(
self
.
lmax
/
2
),
:
self
.
mmax
],
y
[...,
-
math
.
floor
(
self
.
lmax
/
2
)
:,
:
self
.
mmax
]),
dim
=-
2
)
y
=
torch
.
cat
((
y
[...,
:
math
.
ceil
(
self
.
lmax
/
2
),
:
self
.
mmax
],
y
[...,
-
math
.
floor
(
self
.
lmax
/
2
)
:,
:
self
.
mmax
]),
dim
=-
2
)
...
@@ -371,24 +380,24 @@ class RealFFT2(nn.Module):
...
@@ -371,24 +380,24 @@ class RealFFT2(nn.Module):
class
InverseRealFFT2
(
nn
.
Module
):
class
InverseRealFFT2
(
nn
.
Module
):
"""
"""
Helper routine to wrap inverse FFT similarly to the SHT
Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
"""
Initialize InverseRealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super
(
InverseRealFFT2
,
self
).
__init__
()
super
(
InverseRealFFT2
,
self
).
__init__
()
self
.
nlat
=
nlat
self
.
nlat
=
nlat
...
@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module):
...
@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Forward pass
of I
nverse
R
ealFFT2.
Forward pass
: compute i
nverse
r
eal
FFT2
D
.
Parameters
Parameters
-----------
-----------
x : torch.Tensor
x : torch.Tensor
Input
tensor with shape (batch, channels, nlat, mmax)
Input
FFT coefficients
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon)
Reconstructed spatial signal
"""
"""
return
torch
.
fft
.
irfft2
(
x
,
dim
=
(
-
2
,
-
1
),
s
=
(
self
.
nlat
,
self
.
nlon
),
norm
=
"ortho"
)
return
torch
.
fft
.
irfft2
(
x
,
dim
=
(
-
2
,
-
1
),
s
=
(
self
.
nlat
,
self
.
nlon
),
norm
=
"ortho"
)
class
LayerNorm
(
nn
.
Module
):
class
LayerNorm
(
nn
.
Module
):
"""
"""
Wrapper class that moves the channel dimension to the end
Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
"""
"""
def
__init__
(
self
,
in_channels
,
eps
=
1e-05
,
elementwise_affine
=
True
,
bias
=
True
,
device
=
None
,
dtype
=
None
):
def
__init__
(
self
,
in_channels
,
eps
=
1e-05
,
elementwise_affine
=
True
,
bias
=
True
,
device
=
None
,
dtype
=
None
):
"""
Initialize LayerNorm module.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type, by default None
"""
super
().
__init__
()
super
().
__init__
()
self
.
channel_dim
=
-
3
self
.
channel_dim
=
-
3
...
@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
...
@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Forward pass
of LayerNorm
.
Forward pass
with channel dimension handling
.
Parameters
Parameters
-----------
-----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
with channel dimension at -3
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Normalized tensor
Normalized tensor
with same shape as input
"""
"""
return
self
.
norm
(
x
.
transpose
(
self
.
channel_dim
,
-
1
)).
transpose
(
-
1
,
self
.
channel_dim
)
return
self
.
norm
(
x
.
transpose
(
self
.
channel_dim
,
-
1
)).
transpose
(
-
1
,
self
.
channel_dim
)
class
SpectralConvS2
(
nn
.
Module
):
class
SpectralConvS2
(
nn
.
Module
):
"""
"""
Spectral convolution layer for spherical data.
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
Parameters
-----------
-----------
forward_transform : nn.Module
forward_transform : nn.Module
Forward transform (
e.g., RealSH
T)
Forward transform (
SHT or FF
T)
inverse_transform : nn.Module
inverse_transform : nn.Module
Inverse transform (
e.g., InverseRealSH
T)
Inverse transform (
ISHT or IFF
T)
in_channels : int
in_channels : int
Number of input channels
Number of input channels
out_channels : int
out_channels : int
...
@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
...
@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
gain : float, optional
gain : float, optional
Gain factor for weight initialization, by default 2.0
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
operator_type : str, optional
Type of spectral operator, by default "driscoll-healy"
Type of spectral operator
("driscoll-healy", "diagonal", "block-diagonal")
, by default "driscoll-healy"
lr_scale_exponent : int, optional
lr_scale_exponent : int, optional
Learning rate scal
e
exponent, by default 0
Learning rate scal
ing
exponent, by default 0
bias : bool, optional
bias : bool, optional
Whether to use bias, by default False
Whether to use bias, by default False
"""
"""
def
__init__
(
self
,
forward_transform
,
inverse_transform
,
in_channels
,
out_channels
,
gain
=
2.0
,
operator_type
=
"driscoll-healy"
,
lr_scale_exponent
=
0
,
bias
=
False
):
def
__init__
(
self
,
forward_transform
,
inverse_transform
,
in_channels
,
out_channels
,
gain
=
2.0
,
operator_type
=
"driscoll-healy"
,
lr_scale_exponent
=
0
,
bias
=
False
):
super
(
SpectralConvS2
,
self
).
__init__
()
super
().
__init__
()
self
.
forward_transform
=
forward_transform
self
.
forward_transform
=
forward_transform
self
.
inverse_transform
=
inverse_transform
self
.
inverse_transform
=
inverse_transform
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
modes_lat
=
self
.
inverse_transform
.
lmax
self
.
modes_lon
=
self
.
inverse_transform
.
mmax
self
.
scale_residual
=
(
self
.
forward_transform
.
nlat
!=
self
.
inverse_transform
.
nlat
)
or
(
self
.
forward_transform
.
nlon
!=
self
.
inverse_transform
.
nlon
)
# remember factorization details
self
.
operator_type
=
operator_type
self
.
operator_type
=
operator_type
self
.
lr_scale_exponent
=
lr_scale_exponent
# initialize the weights
assert
self
.
inverse_transform
.
lmax
==
self
.
modes_lat
scale
=
math
.
sqrt
(
gain
/
in_channels
)
assert
self
.
inverse_transform
.
mmax
==
self
.
modes_lon
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
in_channels
,
dtype
=
torch
.
cfloat
))
if
bias
:
weight_shape
=
[
out_channels
,
in_channels
]
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
cfloat
))
if
self
.
operator_type
==
"diagonal"
:
weight_shape
+=
[
self
.
modes_lat
,
self
.
modes_lon
]
self
.
contract_func
=
"...ilm,oilm->...olm"
elif
self
.
operator_type
==
"block-diagonal"
:
weight_shape
+=
[
self
.
modes_lat
,
self
.
modes_lon
,
self
.
modes_lon
]
self
.
contract_func
=
"...ilm,oilnm->...oln"
elif
self
.
operator_type
==
"driscoll-healy"
:
weight_shape
+=
[
self
.
modes_lat
]
self
.
contract_func
=
"...ilm,oil->...olm"
else
:
else
:
self
.
bias
=
None
raise
NotImplementedError
(
f
"Unkonw operator type f
{
self
.
operator_type
}
"
)
# form weight tensors
scale
=
math
.
sqrt
(
gain
/
in_channels
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
*
weight_shape
,
dtype
=
torch
.
complex64
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
out_channels
,
1
,
1
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
...
@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
...
@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
Returns
Returns
-------
-------
t
orch.Tensor
t
uple
Output tensor after spectral convolution
Tuple containing (output, residual) tensors
"""
"""
# apply forward transform
dtype
=
x
.
dtype
x
=
self
.
forward_transform
(
x
)
x
=
x
.
float
()
residual
=
x
# apply spectral convolution
with
torch
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
torch
.
einsum
(
"bilm,oim->bolm"
,
x
,
self
.
weight
)
x
=
self
.
forward_transform
(
x
)
if
self
.
scale_residual
:
residual
=
self
.
inverse_transform
(
x
)
# apply inverse transform
x
=
torch
.
einsum
(
self
.
contract_func
,
x
,
self
.
weight
)
x
=
self
.
inverse_transform
(
x
)
# add bias if present
with
torch
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
if
self
.
bias
is
not
None
:
x
=
self
.
inverse_transform
(
x
)
x
=
x
+
self
.
bias
.
view
(
1
,
-
1
,
1
,
1
)
return
x
if
hasattr
(
self
,
"bias"
):
x
=
x
+
self
.
bias
x
=
x
.
type
(
dtype
)
return
x
,
residual
class
PositionEmbedding
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
PositionEmbedding
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""
"""
Abstract base class for position embeddings on spherical data.
Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters
Parameters
-----------
-----------
...
@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
...
@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
PositionEmbedding
,
self
).
__init__
()
super
().
__init__
()
self
.
img_shape
=
img_shape
self
.
img_shape
=
img_shape
self
.
grid
=
grid
self
.
num_chans
=
num_chans
self
.
num_chans
=
num_chans
@
abc
.
abstractmethod
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
"""
"""
Abstract forward method for
position embedding.
Forward pass: add
position embedding
s to input
.
Parameters
Parameters
-----------
-----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
Returns
-------
torch.Tensor
Input tensor with position embeddings added
"""
"""
pas
s
return
x
+
self
.
position_embedding
s
class
SequencePositionEmbedding
(
PositionEmbedding
):
class
SequencePositionEmbedding
(
PositionEmbedding
):
"""
"""
Sequence-based position embedding
for spherical data
.
S
tandard s
equence-based position embedding.
This module
adds position embeddings based on the sequence of latitude and longitud
e
This module
implements sinusoidal position embeddings similar to thos
e
coordinates, providing spatial context to the model
.
used in the original Transformer paper, adapted for 2D spatial data
.
Parameters
Parameters
-----------
-----------
...
@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding):
...
@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding):
num_chans : int, optional
num_chans : int, optional
Number of channels, by default 1
Number of channels, by default 1
"""
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
SequencePositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
# create position embeddings
with
torch
.
no_grad
():
pos_embed
=
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
])
# alternating custom position embeddings
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
pos
=
torch
.
arange
(
self
.
img_shape
[
0
]
*
self
.
img_shape
[
1
]).
reshape
(
1
,
1
,
*
self
.
img_shape
).
repeat
(
1
,
self
.
num_chans
,
1
,
1
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
)
k
=
torch
.
arange
(
self
.
num_chans
).
reshape
(
1
,
self
.
num_chans
,
1
,
1
)
denom
=
torch
.
pow
(
10000
,
2
*
k
/
self
.
num_chans
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
pos_embed
=
torch
.
where
(
k
%
2
==
0
,
torch
.
sin
(
pos
/
denom
),
torch
.
cos
(
pos
/
denom
))
"""
Forward pass of sequence position embedding.
# register tensor
self
.
register_buffer
(
"position_embeddings"
,
pos_embed
.
float
())
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with position embeddings added
"""
return
x
+
self
.
pos_embed
class
SpectralPositionEmbedding
(
PositionEmbedding
):
class
SpectralPositionEmbedding
(
PositionEmbedding
):
r
"""
"""
Spectral position embedding for spherical
data
.
Spectral position embedding
s
for spherical
transformers
.
This module adds position embeddings in the spectral domain using spherical harmonics,
This module creates position embeddings in the spectral domain using
providing spectral context to the model.
spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters
Parameters
-----------
-----------
...
@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding):
...
@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding):
num_chans : int, optional
num_chans : int, optional
Number of channels, by default 1
Number of channels, by default 1
"""
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
SpectralPositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
# create spectral position embeddings
# compute maximum required frequency and prepare isht
pos_embed
=
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
]
//
2
+
1
,
dtype
=
torch
.
cfloat
)
lmax
=
math
.
floor
(
math
.
sqrt
(
self
.
num_chans
))
+
1
nn
.
init
.
trunc_normal_
(
pos_embed
.
real
,
std
=
0.02
)
isht
=
InverseRealSHT
(
nlat
=
self
.
img_shape
[
0
],
nlon
=
self
.
img_shape
[
1
],
lmax
=
lmax
,
mmax
=
lmax
,
grid
=
grid
)
nn
.
init
.
trunc_normal_
(
pos_embed
.
imag
,
std
=
0.02
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
# fill position embedding
"""
with
torch
.
no_grad
():
Forward pass of spectral position embedding.
pos_embed_freq
=
torch
.
zeros
(
1
,
self
.
num_chans
,
isht
.
lmax
,
isht
.
mmax
,
dtype
=
torch
.
complex64
)
Parameters
for
i
in
range
(
self
.
num_chans
):
-----------
l
=
math
.
floor
(
math
.
sqrt
(
i
))
x : torch.Tensor
m
=
i
-
l
**
2
-
l
Input tensor
if
m
<
0
:
Returns
pos_embed_freq
[
0
,
i
,
l
,
-
m
]
=
1.0j
-------
else
:
torch.Tensor
pos_embed_freq
[
0
,
i
,
l
,
m
]
=
1.0
Tensor with spectral position embeddings added
"""
# compute spatial position embeddings
return
x
+
self
.
pos_embed
pos_embed
=
isht
(
pos_embed_freq
)
# normalization
pos_embed
=
pos_embed
/
torch
.
amax
(
pos_embed
.
abs
(),
dim
=
(
-
1
,
-
2
),
keepdim
=
True
)
# register tensor
self
.
register_buffer
(
"position_embeddings"
,
pos_embed
)
class
LearnablePositionEmbedding
(
PositionEmbedding
):
class
LearnablePositionEmbedding
(
PositionEmbedding
):
r
"""
"""
Learnable position embedding for spherical
data
.
Learnable position embedding
s
for spherical
transformers
.
This module
add
s learnable position embeddings that
are optimized during training,
This module
provide
s learnable position embeddings that
can be either
allowing the model to learn optimal spatial representation
s.
latitude-only or full latitude-longitude embedding
s.
Parameters
Parameters
-----------
-----------
...
@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
...
@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
num_chans : int, optional
num_chans : int, optional
Number of channels, by default 1
Number of channels, by default 1
embed_type : str, optional
embed_type : str, optional
Embedding type ("lat"
, "lon", or "both
"), by default "lat"
Embedding type ("lat"
or "latlon
"), by default "lat"
"""
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
,
embed_type
=
"lat"
):
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
,
embed_type
=
"lat"
):
super
(
LearnablePositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
self
.
embed_type
=
embed_type
if
embed_type
==
"lat"
:
# latitude embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
1
))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
elif
embed_type
==
"lon"
:
# longitude embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
1
,
img_shape
[
1
]))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
elif
embed_type
==
"latlon"
:
# full lat-lon embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
]))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
else
:
raise
ValueError
(
f
"Unknown embedding type
{
embed_type
}
"
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
embed_type
==
"latlon"
:
"""
self
.
position_embeddings
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
self
.
num_chans
,
self
.
img_shape
[
0
],
self
.
img_shape
[
1
]))
Forward pass of learnable position embedding.
elif
embed_type
==
"lat"
:
self
.
position_embeddings
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
self
.
num_chans
,
self
.
img_shape
[
0
],
1
))
Parameters
else
:
-----------
raise
ValueError
(
f
"Unknown learnable position embedding type
{
embed_type
}
"
)
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return
x
+
self
.
pos_embed
# class SpiralPositionEmbedding(PositionEmbedding):
# class SpiralPositionEmbedding(PositionEmbedding):
# """
# """
...
@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
...
@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
# self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment