Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
313b1b73
Commit
313b1b73
authored
Jun 26, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Corrected docstrings in _layers.py
parent
e4879676
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
230 additions
and
222 deletions
+230
-222
torch_harmonics/examples/models/_layers.py
torch_harmonics/examples/models/_layers.py
+230
-222
No files found.
torch_harmonics/examples/models/_layers.py
View file @
313b1b73
...
...
@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
"""
Internal function to fill tensor with truncated normal distribution values.
Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters
-----------
tensor : torch.Tensor
Tensor to
fill with values
Tensor to
initialize
mean : float
Mean of the normal distribution
std : float
...
...
@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Returns
-------
torch.Tensor
The fill
ed tensor
Initializ
ed tensor
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def
norm_cdf
(
x
):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
...
...
@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@
torch
.
jit
.
script
def
drop_path
(
x
:
torch
.
Tensor
,
drop_prob
:
float
=
0.0
,
training
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
"""
if
drop_prob
==
0.0
or
not
training
:
return
x
...
...
@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class
DropPath
(
nn
.
Module
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters
-----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
def
__init__
(
self
,
drop_prob
=
None
):
"""
Initialize DropPath module.
Parameters
-----------
drop_prob : float, optional
Dropout probability, by default None
"""
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
"""
Forward pass with drop path.
Forward pass with drop path
regularization
.
Parameters
-----------
...
...
@@ -177,7 +180,7 @@ class DropPath(nn.Module):
Returns
-------
torch.Tensor
Output tensor with potential
drop path applied
Output tensor with potential
path dropping
"""
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
...
...
@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
"""
Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters
-----------
img_size : tuple, optional
...
...
@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
Parameters
-----------
x : torch.Tensor
Input tensor
with
shape (batch, channels, height, width)
Input tensor
of
shape (batch
_size
, channels, height, width)
Returns
-------
torch.Tensor
Embedded patches with
shape (batch, embed_dim, num_patches)
Patch embeddings of
shape (batch
_size
, embed_dim, num_patches)
"""
# gather input
B
,
C
,
H
,
W
=
x
.
shape
...
...
@@ -235,6 +241,9 @@ class MLP(nn.Module):
"""
Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters
-----------
in_features : int
...
...
@@ -252,7 +261,7 @@ class MLP(nn.Module):
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
gain : float, optional
Gain factor for
outpu
t initialization, by default 1.0
Gain factor for
weigh
t initialization, by default 1.0
"""
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
ReLU
,
output_bias
=
False
,
drop_rate
=
0.0
,
checkpointing
=
False
,
gain
=
1.0
):
...
...
@@ -325,24 +334,24 @@ class MLP(nn.Module):
class
RealFFT2
(
nn
.
Module
):
"""
Helper routine to wrap FFT similarly to the SHT
Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
"""
Initialize RealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super
(
RealFFT2
,
self
).
__init__
()
self
.
nlat
=
nlat
...
...
@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
def
forward
(
self
,
x
):
"""
Forward pass
of R
ealFFT2.
Forward pass
: compute r
eal
FFT2
D
.
Parameters
-----------
x : torch.Tensor
Input tensor
with shape (batch, channels, nlat, nlon)
Input tensor
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax)
FFT coefficients
"""
y
=
torch
.
fft
.
rfft2
(
x
,
dim
=
(
-
2
,
-
1
),
norm
=
"ortho"
)
y
=
torch
.
cat
((
y
[...,
:
math
.
ceil
(
self
.
lmax
/
2
),
:
self
.
mmax
],
y
[...,
-
math
.
floor
(
self
.
lmax
/
2
)
:,
:
self
.
mmax
]),
dim
=-
2
)
...
...
@@ -371,24 +380,24 @@ class RealFFT2(nn.Module):
class
InverseRealFFT2
(
nn
.
Module
):
"""
Helper routine to wrap inverse FFT similarly to the SHT
Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
):
"""
Initialize InverseRealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super
(
InverseRealFFT2
,
self
).
__init__
()
self
.
nlat
=
nlat
...
...
@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module):
def
forward
(
self
,
x
):
"""
Forward pass
of I
nverse
R
ealFFT2.
Forward pass
: compute i
nverse
r
eal
FFT2
D
.
Parameters
-----------
x : torch.Tensor
Input
tensor with shape (batch, channels, nlat, mmax)
Input
FFT coefficients
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon)
Reconstructed spatial signal
"""
return
torch
.
fft
.
irfft2
(
x
,
dim
=
(
-
2
,
-
1
),
s
=
(
self
.
nlat
,
self
.
nlon
),
norm
=
"ortho"
)
class
LayerNorm
(
nn
.
Module
):
"""
Wrapper class that moves the channel dimension to the end
Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
"""
def
__init__
(
self
,
in_channels
,
eps
=
1e-05
,
elementwise_affine
=
True
,
bias
=
True
,
device
=
None
,
dtype
=
None
):
"""
Initialize LayerNorm module.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type, by default None
"""
super
().
__init__
()
self
.
channel_dim
=
-
3
...
...
@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
def
forward
(
self
,
x
):
"""
Forward pass
of LayerNorm
.
Forward pass
with channel dimension handling
.
Parameters
-----------
x : torch.Tensor
Input tensor
Input tensor
with channel dimension at -3
Returns
-------
torch.Tensor
Normalized tensor
Normalized tensor
with same shape as input
"""
return
self
.
norm
(
x
.
transpose
(
self
.
channel_dim
,
-
1
)).
transpose
(
-
1
,
self
.
channel_dim
)
class
SpectralConvS2
(
nn
.
Module
):
"""
Spectral convolution layer for spherical data.
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
-----------
forward_transform : nn.Module
Forward transform (
e.g., RealSH
T)
Forward transform (
SHT or FF
T)
inverse_transform : nn.Module
Inverse transform (
e.g., InverseRealSH
T)
Inverse transform (
ISHT or IFF
T)
in_channels : int
Number of input channels
out_channels : int
...
...
@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
gain : float, optional
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
Type of spectral operator, by default "driscoll-healy"
Type of spectral operator
("driscoll-healy", "diagonal", "block-diagonal")
, by default "driscoll-healy"
lr_scale_exponent : int, optional
Learning rate scal
e
exponent, by default 0
Learning rate scal
ing
exponent, by default 0
bias : bool, optional
Whether to use bias, by default False
"""
def
__init__
(
self
,
forward_transform
,
inverse_transform
,
in_channels
,
out_channels
,
gain
=
2.0
,
operator_type
=
"driscoll-healy"
,
lr_scale_exponent
=
0
,
bias
=
False
):
super
(
SpectralConvS2
,
self
).
__init__
()
super
().
__init__
()
self
.
forward_transform
=
forward_transform
self
.
inverse_transform
=
inverse_transform
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
modes_lat
=
self
.
inverse_transform
.
lmax
self
.
modes_lon
=
self
.
inverse_transform
.
mmax
self
.
scale_residual
=
(
self
.
forward_transform
.
nlat
!=
self
.
inverse_transform
.
nlat
)
or
(
self
.
forward_transform
.
nlon
!=
self
.
inverse_transform
.
nlon
)
# remember factorization details
self
.
operator_type
=
operator_type
self
.
lr_scale_exponent
=
lr_scale_exponent
# initialize the weights
scale
=
math
.
sqrt
(
gain
/
in_channels
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
in_channels
,
dtype
=
torch
.
cfloat
))
assert
self
.
inverse_transform
.
lmax
==
self
.
modes_lat
assert
self
.
inverse_transform
.
mmax
==
self
.
modes_lon
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
cfloat
))
weight_shape
=
[
out_channels
,
in_channels
]
if
self
.
operator_type
==
"diagonal"
:
weight_shape
+=
[
self
.
modes_lat
,
self
.
modes_lon
]
self
.
contract_func
=
"...ilm,oilm->...olm"
elif
self
.
operator_type
==
"block-diagonal"
:
weight_shape
+=
[
self
.
modes_lat
,
self
.
modes_lon
,
self
.
modes_lon
]
self
.
contract_func
=
"...ilm,oilnm->...oln"
elif
self
.
operator_type
==
"driscoll-healy"
:
weight_shape
+=
[
self
.
modes_lat
]
self
.
contract_func
=
"...ilm,oil->...olm"
else
:
self
.
bias
=
None
raise
NotImplementedError
(
f
"Unkonw operator type f
{
self
.
operator_type
}
"
)
# form weight tensors
scale
=
math
.
sqrt
(
gain
/
in_channels
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
*
weight_shape
,
dtype
=
torch
.
complex64
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
out_channels
,
1
,
1
))
def
forward
(
self
,
x
):
"""
...
...
@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
Returns
-------
t
orch.Tensor
Output tensor after spectral convolution
t
uple
Tuple containing (output, residual) tensors
"""
# apply forward transform
x
=
self
.
forward_transform
(
x
)
dtype
=
x
.
dtype
x
=
x
.
float
()
residual
=
x
# apply spectral convolution
x
=
torch
.
einsum
(
"bilm,oim->bolm"
,
x
,
self
.
weight
)
with
torch
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
self
.
forward_transform
(
x
)
if
self
.
scale_residual
:
residual
=
self
.
inverse_transform
(
x
)
# apply inverse transform
x
=
self
.
inverse_transform
(
x
)
x
=
torch
.
einsum
(
self
.
contract_func
,
x
,
self
.
weight
)
# add bias if present
if
self
.
bias
is
not
None
:
x
=
x
+
self
.
bias
.
view
(
1
,
-
1
,
1
,
1
)
with
torch
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
self
.
inverse_transform
(
x
)
return
x
if
hasattr
(
self
,
"bias"
):
x
=
x
+
self
.
bias
x
=
x
.
type
(
dtype
)
return
x
,
residual
class
PositionEmbedding
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""
Abstract base class for position embeddings on spherical data.
Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters
-----------
...
...
@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
PositionEmbedding
,
self
).
__init__
()
super
().
__init__
()
self
.
img_shape
=
img_shape
self
.
grid
=
grid
self
.
num_chans
=
num_chans
@
abc
.
abstractmethod
def
forward
(
self
,
x
:
torch
.
Tensor
):
"""
Abstract forward method for
position embedding.
Forward pass: add
position embedding
s to input
.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Input tensor with position embeddings added
"""
pas
s
return
x
+
self
.
position_embedding
s
class
SequencePositionEmbedding
(
PositionEmbedding
):
"""
Sequence-based position embedding
for spherical data
.
S
tandard s
equence-based position embedding.
This module
adds position embeddings based on the sequence of latitude and longitud
e
coordinates, providing spatial context to the model
.
This module
implements sinusoidal position embeddings similar to thos
e
used in the original Transformer paper, adapted for 2D spatial data
.
Parameters
-----------
...
...
@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
SequencePositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
# create position embeddings
pos_embed
=
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
])
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
)
with
torch
.
no_grad
():
# alternating custom position embeddings
pos
=
torch
.
arange
(
self
.
img_shape
[
0
]
*
self
.
img_shape
[
1
]).
reshape
(
1
,
1
,
*
self
.
img_shape
).
repeat
(
1
,
self
.
num_chans
,
1
,
1
)
k
=
torch
.
arange
(
self
.
num_chans
).
reshape
(
1
,
self
.
num_chans
,
1
,
1
)
denom
=
torch
.
pow
(
10000
,
2
*
k
/
self
.
num_chans
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
"""
Forward pass of sequence position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with position embeddings added
"""
return
x
+
self
.
pos_embed
pos_embed
=
torch
.
where
(
k
%
2
==
0
,
torch
.
sin
(
pos
/
denom
),
torch
.
cos
(
pos
/
denom
))
# register tensor
self
.
register_buffer
(
"position_embeddings"
,
pos_embed
.
float
())
class
SpectralPositionEmbedding
(
PositionEmbedding
):
r
"""
Spectral position embedding for spherical
data
.
"""
Spectral position embedding
s
for spherical
transformers
.
This module adds position embeddings in the spectral domain using spherical harmonics,
providing spectral context to the model.
This module creates position embeddings in the spectral domain using
spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters
-----------
...
...
@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
):
super
(
SpectralPositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
# create spectral position embeddings
pos_embed
=
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
]
//
2
+
1
,
dtype
=
torch
.
cfloat
)
nn
.
init
.
trunc_normal_
(
pos_embed
.
real
,
std
=
0.02
)
nn
.
init
.
trunc_normal_
(
pos_embed
.
imag
,
std
=
0.02
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
)
# compute maximum required frequency and prepare isht
lmax
=
math
.
floor
(
math
.
sqrt
(
self
.
num_chans
))
+
1
isht
=
InverseRealSHT
(
nlat
=
self
.
img_shape
[
0
],
nlon
=
self
.
img_shape
[
1
],
lmax
=
lmax
,
mmax
=
lmax
,
grid
=
grid
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
"""
Forward pass of spectral position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with spectral position embeddings added
"""
return
x
+
self
.
pos_embed
# fill position embedding
with
torch
.
no_grad
():
pos_embed_freq
=
torch
.
zeros
(
1
,
self
.
num_chans
,
isht
.
lmax
,
isht
.
mmax
,
dtype
=
torch
.
complex64
)
for
i
in
range
(
self
.
num_chans
):
l
=
math
.
floor
(
math
.
sqrt
(
i
))
m
=
i
-
l
**
2
-
l
if
m
<
0
:
pos_embed_freq
[
0
,
i
,
l
,
-
m
]
=
1.0j
else
:
pos_embed_freq
[
0
,
i
,
l
,
m
]
=
1.0
# compute spatial position embeddings
pos_embed
=
isht
(
pos_embed_freq
)
# normalization
pos_embed
=
pos_embed
/
torch
.
amax
(
pos_embed
.
abs
(),
dim
=
(
-
1
,
-
2
),
keepdim
=
True
)
# register tensor
self
.
register_buffer
(
"position_embeddings"
,
pos_embed
)
class
LearnablePositionEmbedding
(
PositionEmbedding
):
r
"""
Learnable position embedding for spherical
data
.
"""
Learnable position embedding
s
for spherical
transformers
.
This module
add
s learnable position embeddings that
are optimized during training,
allowing the model to learn optimal spatial representation
s.
This module
provide
s learnable position embeddings that
can be either
latitude-only or full latitude-longitude embedding
s.
Parameters
-----------
...
...
@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
embed_type : str, optional
Embedding type ("lat"
, "lon", or "both
"), by default "lat"
Embedding type ("lat"
or "latlon
"), by default "lat"
"""
def
__init__
(
self
,
img_shape
=
(
480
,
960
),
grid
=
"equiangular"
,
num_chans
=
1
,
embed_type
=
"lat"
):
super
(
LearnablePositionEmbedding
,
self
).
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
self
.
embed_type
=
embed_type
if
embed_type
==
"lat"
:
# latitude embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
1
))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
elif
embed_type
==
"lon"
:
# longitude embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
1
,
img_shape
[
1
]))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
elif
embed_type
==
"latlon"
:
# full lat-lon embedding
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_chans
,
img_shape
[
0
],
img_shape
[
1
]))
nn
.
init
.
trunc_normal_
(
pos_embed
,
std
=
0.02
)
self
.
register_parameter
(
"pos_embed"
,
pos_embed
)
else
:
raise
ValueError
(
f
"Unknown embedding type
{
embed_type
}
"
)
super
().
__init__
(
img_shape
=
img_shape
,
grid
=
grid
,
num_chans
=
num_chans
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
"""
Forward pass of learnable position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return
x
+
self
.
pos_embed
if
embed_type
==
"latlon"
:
self
.
position_embeddings
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
self
.
num_chans
,
self
.
img_shape
[
0
],
self
.
img_shape
[
1
]))
elif
embed_type
==
"lat"
:
self
.
position_embeddings
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
self
.
num_chans
,
self
.
img_shape
[
0
],
1
))
else
:
raise
ValueError
(
f
"Unknown learnable position embedding type
{
embed_type
}
"
)
# class SpiralPositionEmbedding(PositionEmbedding):
# """
...
...
@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
# self.register_buffer("position_embeddings", pos_embed.float())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment