OpenDAS / torch-harmonics / Commits

Commit 901e8635
Authored Jul 16, 2025 by Andrea Paris; committed by Boris Bonev, Jul 21, 2025
Parent: 6373534a

    removed docstrings from forward passes
Showing 13 changed files with 32 additions and 568 deletions.
examples/baseline_models/segformer.py                  +1   -26
torch_harmonics/convolution.py                         +1   -13
torch_harmonics/distributed/distributed_resample.py    +0   -13
torch_harmonics/distributed/primitives.py              +1   -57
torch_harmonics/examples/losses.py                     +1   -58
torch_harmonics/examples/metrics.py                    +3   -45
torch_harmonics/examples/models/_layers.py             +9   -117
torch_harmonics/examples/models/lsno.py                +3   -39
torch_harmonics/examples/models/s2segformer.py         +4   -78
torch_harmonics/examples/models/s2transformer.py       +3   -39
torch_harmonics/examples/models/s2unet.py              +2   -26
torch_harmonics/examples/models/sfno.py                +3   -39
torch_harmonics/random_fields.py                       +1   -18
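Every hunk below follows the same pattern: the NumPy-style docstring attached to a forward (or closely related) method is deleted while the method signature and body are left untouched. A minimal before/after sketch of that pattern on a hypothetical module (illustration only, not code taken from this diff):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # hypothetical module used only to illustrate the pattern of this commit

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    # before: forward carried a full "Parameters / Returns" docstring
    # after:  only the body remains, as in the hunks below
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)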
examples/baseline_models/segformer.py

@@ -103,19 +103,6 @@ class OverlapPatchMerging(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
-        """
-        Forward pass through the OverlapPatchMerging layer.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after patch merging
-        """
        x = self.conv(x)

        # permute

@@ -204,19 +191,7 @@ class MixFFN(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the MixFFN module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after processing
-        """
        residual = x

        # norm
torch_harmonics/convolution.py

@@ -693,19 +693,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass
-
-        Parameters
-        -----------
-        x: torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        out: torch.Tensor
-            Output tensor
-        """
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)
torch_harmonics/distributed/distributed_resample.py

@@ -248,19 +248,6 @@ class DistributedResampleS2(nn.Module):
        return x

    def forward(self, x: torch.Tensor):
-        """
-        Forward pass for distributed resampling.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Resampled tensor with shape (batch, channels, nlat_out, nlon_out)
-        """
        if self.skip_resampling:
            return x
torch_harmonics/distributed/primitives.py

@@ -150,23 +150,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
-        r"""
-        Forward pass for distributed azimuthal transpose.
-
-        Parameters
-        ----------
-        x: torch.Tensor
-            The tensor to transpose
-        dims: List[int]
-            The dimensions to transpose
-        dim1_split_sizes: List[int]
-            The split sizes for the second dimension
-
-        Returns
-        -------
-        x: torch.Tensor
-            The transposed tensor
-        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])

@@ -205,23 +189,7 @@ class distributed_transpose_polar(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
-        r"""
-        Forward pass for distributed polar transpose.
-
-        Parameters
-        ----------
-        x: torch.Tensor
-            The tensor to transpose
-        dim: List[int]
-            The dimensions to transpose
-        dim1_split_sizes: List[int]
-            The split sizes for the second dimension
-
-        Returns
-        -------
-        x: torch.Tensor
-            The transposed tensor
-        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])

@@ -363,19 +331,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
-        r"""
-        Forward pass for copying to polar region.
-
-        Parameters
-        ----------
-        input_: torch.Tensor
-            The tensor to copy
-
-        Returns
-        -------
-        input_: torch.Tensor
-            The tensor to copy
-        """
        return input_

    @staticmethod

@@ -409,19 +365,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
-        r"""
-        Forward pass for copying to azimuth region.
-
-        Parameters
-        ----------
-        input_: torch.Tensor
-            The tensor to copy
-
-        Returns
-        -------
-        input_: torch.Tensor
-            The tensor to copy
-        """
        return input_

    @staticmethod
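The distributed_transpose_* and _CopyTo*Region primitives above are torch.autograd.Function subclasses whose forward passes are decorated with custom_fwd(device_type="cuda") so they behave correctly under autocast. A minimal sketch of that decorator pattern with a toy pass-through function (assuming torch.amp.custom_fwd/custom_bwd from a recent PyTorch; this is not the library's actual primitive):

import torch
from torch.amp import custom_fwd, custom_bwd

class _PassThrough(torch.autograd.Function):
    # toy identity op following the same decorator pattern as the primitives above

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        # forward returns the input unchanged
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # backward propagates the incoming gradient unchanged
        return grad_output

# usage: y = _PassThrough.apply(x)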
torch_harmonics/examples/losses.py

@@ -115,21 +115,7 @@ class DiceLossS2(nn.Module):
        self.register_buffer("weight", weight.unsqueeze(0))

    def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the Dice loss.
-
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Dice loss value
-        """
        prd = nn.functional.softmax(prd, dim=1)

        # mask values

@@ -205,21 +191,6 @@ class CrossEntropyLossS2(nn.Module):
        self.register_buffer("quad_weights", q)

    def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the cross-entropy loss.
-
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Cross-entropy loss value
-        """
        # compute log softmax
        logits = nn.functional.log_softmax(prd, dim=1)

@@ -266,25 +237,6 @@ class FocalLossS2(nn.Module):
        self.register_buffer("quad_weights", q)

    def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
-        """
-        Forward pass of the focal loss.
-
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-        alpha : float, optional
-            Alpha parameter for focal loss, by default 0.25
-        gamma : float, optional
-            Gamma parameter for focal loss, by default 2
-
-        Returns
-        -------
-        torch.Tensor
-            Focal loss value
-        """
        # compute logits
        logits = nn.functional.log_softmax(prd, dim=1)

@@ -339,16 +291,7 @@ class SphericalLossBase(nn.Module, ABC):
        return loss

    def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Common forward pass that handles masking and reduction.
-
-        Args:
-            prd (torch.Tensor): Prediction tensor
-            tar (torch.Tensor): Target tensor
-            mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
-
-        Returns:
-            torch.Tensor: Final loss value
-        """
        loss_term = self._compute_loss_term(prd, tar)
        # Integrate over the sphere for each item in the batch
        loss = self._integrate_sphere(loss_term, mask)
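All of these losses share one structure: a pointwise term is computed from softmax/log-softmax outputs and then integrated over the sphere using quadrature weights stored as a buffer. A rough sketch of that idea (the weight shape and the final mean reduction are assumptions, not the library's exact implementation):

import torch
import torch.nn as nn

def spherical_cross_entropy(prd: torch.Tensor, tar: torch.Tensor,
                            quad_weights: torch.Tensor) -> torch.Tensor:
    # prd: (batch, classes, nlat, nlon) logits; tar: (batch, nlat, nlon) integer labels
    logits = nn.functional.log_softmax(prd, dim=1)
    # pointwise negative log-likelihood at the target class
    nll = -logits.gather(1, tar.unsqueeze(1)).squeeze(1)   # (batch, nlat, nlon)
    # assumed quadrature weights of shape (nlat, 1), broadcast over longitude;
    # the weighted sum approximates the integral over the sphere
    return (nll * quad_weights).sum(dim=(-2, -1)).mean()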
torch_harmonics/examples/metrics.py

@@ -169,21 +169,7 @@ class BaseMetricS2(nn.Module):
        self.register_buffer("weight", weight.unsqueeze(0))

    def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Compute base statistics (TP, FP, FN, TN) for the given predictions and ground truth.
-
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-
-        Returns
-        -------
-        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-            Tuple containing (tp, fp, fn, tn) statistics
-        """
        # convert logits to class predictions
        pred_class = _predict_classes(pred)

@@ -240,21 +226,7 @@ class IntersectionOverUnionS2(BaseMetricS2):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

    def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
-        """
-        Compute IoU score for the given predictions and ground truth.
-
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-
-        Returns
-        -------
-        torch.Tensor
-            IoU score
-        """
        tp, fp, fn, tn = self._forward(pred, truth)

        # compute score

@@ -300,21 +272,7 @@ class AccuracyS2(BaseMetricS2):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

    def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
-        """
-        Compute accuracy score for the given predictions and ground truth.
-
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-
-        Returns
-        -------
-        torch.Tensor
-            Accuracy score
-        """
        tp, fp, fn, tn = self._forward(pred, truth)

        # compute score
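Both metrics reduce the same (tp, fp, fn, tn) statistics produced by _forward; only the final formula differs. For reference, the standard definitions (a sketch, not the library's exact reduction):

import torch

def iou_from_stats(tp: torch.Tensor, fp: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
    # intersection over union: |A ∩ B| / |A ∪ B|
    return tp / (tp + fp + fn).clamp(min=1e-8)

def accuracy_from_stats(tp: torch.Tensor, fp: torch.Tensor,
                        fn: torch.Tensor, tn: torch.Tensor) -> torch.Tensor:
    # fraction of points classified correctly
    return (tp + tn) / (tp + fp + fn + tn).clamp(min=1e-8)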
torch_harmonics/examples/models/_layers.py

@@ -191,19 +191,7 @@ class DropPath(nn.Module):
        self.drop_prob = drop_prob

    def forward(self, x):
-        """
-        Forward pass with drop path regularization.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor with potential path dropping
-        """
        return drop_path(x, self.drop_prob, self.training)

@@ -238,19 +226,7 @@ class PatchEmbed(nn.Module):
            self.proj.bias.is_shared_mp = ["spatial"]

    def forward(self, x):
-        """
-        Forward pass of patch embedding.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, channels, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Patch embeddings of shape (batch_size, embed_dim, num_patches)
-        """
        # gather input
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

@@ -319,35 +295,11 @@ class MLP(nn.Module):
    @torch.jit.ignore
    def checkpoint_forward(self, x):
-        """
-        Forward pass with gradient checkpointing.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor
-        """
        return checkpoint(self.fwd, x)

    def forward(self, x):
-        """
-        Forward pass of the MLP.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor
-        """
        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:

@@ -382,19 +334,7 @@ class RealFFT2(nn.Module):
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
-        """
-        Forward pass: compute real FFT2D.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            FFT coefficients
-        """
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)

        return y

@@ -428,19 +368,7 @@ class InverseRealFFT2(nn.Module):
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
-        """
-        Forward pass: compute inverse real FFT2D.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input FFT coefficients
-
-        Returns
-        -------
-        torch.Tensor
-            Reconstructed spatial signal
-        """
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")

@@ -476,19 +404,7 @@ class LayerNorm(nn.Module):
        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=1e-6, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
-        """
-        Forward pass with channel dimension handling.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with channel dimension at -3
-
-        Returns
-        -------
-        torch.Tensor
-            Normalized tensor with same shape as input
-        """
        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)

@@ -556,19 +472,7 @@ class SpectralConvS2(nn.Module):
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
-        """
-        Forward pass of spectral convolution.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        tuple
-            Tuple containing (output, residual) tensors
-        """
        dtype = x.dtype
        x = x.float()
        residual = x

@@ -614,19 +518,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
        self.num_chans = num_chans

    def forward(self, x: torch.Tensor):
-        """
-        Forward pass: add position embeddings to input.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Input tensor with position embeddings added
-        """
        return x + self.position_embeddings
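RealFFT2 and InverseRealFFT2 above are thin wrappers around torch.fft.rfft2 / torch.fft.irfft2 with norm="ortho", plus an optional truncation to lmax/mmax modes. Without the truncation the pair is an exact round trip, which this small sketch checks:

import torch

nlat, nlon = 64, 128
x = torch.randn(2, 3, nlat, nlon)

# forward real FFT over the two spatial dimensions
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")          # (2, 3, nlat, nlon // 2 + 1)

# inverse transform back to the original spatial shape
x_rec = torch.fft.irfft2(y, dim=(-2, -1), s=(nlat, nlon), norm="ortho")

print(torch.allclose(x, x_rec, atol=1e-5))                  # True up to float32 round-off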
torch_harmonics/examples/models/lsno.py

@@ -110,19 +110,7 @@ class DiscreteContinuousEncoder(nn.Module):
        )

    def forward(self, x):
-        """
-        Forward pass of the discrete-continuous encoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Encoded tensor with reduced spatial resolution
-        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):

@@ -205,19 +193,7 @@ class DiscreteContinuousDecoder(nn.Module):
        )

    def forward(self, x):
-        """
-        Forward pass of the discrete-continuous decoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Decoded tensor with restored spatial resolution
-        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):

@@ -628,19 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module):
        return x

    def forward(self, x):
-        """
-        Forward pass through the complete LSNO model.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
        if self.residual_prediction:
            residual = x
torch_harmonics/examples/models/s2segformer.py

@@ -130,19 +130,7 @@ class OverlapPatchMerging(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
-        """
-        Forward pass of the overlap patch merging module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Merged patches with layer normalization
-        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):

@@ -259,19 +247,6 @@ class MixFFN(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the Mix FFN module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after Mix FFN processing
-        """
        residual = x

        # norm

@@ -382,19 +357,7 @@ class AttentionWrapper(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the attention wrapper.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after attention processing
-        """
        residual = x

        if self.norm is not None:
            x = x.permute(0, 2, 3, 1)

@@ -548,19 +511,6 @@ class TransformerBlock(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the transformer block.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after transformer block processing
-        """
        x = self.fwd(x)

        # apply norm

@@ -664,19 +614,7 @@ class Upsampling(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the upsampling module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Upsampled tensor
-        """
        x = self.upsample(self.mlp(x))
        return x

@@ -871,19 +809,7 @@ class SphericalSegformer(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
-        """
-        Forward pass through the complete spherical segformer model.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
        # encoder:
        features = []
        feat = x
torch_harmonics/examples/models/s2transformer.py

@@ -113,19 +113,7 @@ class DiscreteContinuousEncoder(nn.Module):
        )

    def forward(self, x):
-        """
-        Forward pass of the discrete-continuous encoder.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Encoded tensor with reduced spatial resolution
-        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):

@@ -209,19 +197,7 @@ class DiscreteContinuousDecoder(nn.Module):
        )

    def forward(self, x):
-        """
-        Forward pass of the discrete-continuous decoder.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Decoded tensor with restored spatial resolution
-        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):

@@ -593,19 +569,7 @@ class SphericalTransformer(nn.Module):
        return x

    def forward(self, x):
-        """
-        Forward pass through the complete spherical transformer model.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
        if self.residual_prediction:
            residual = x
torch_harmonics/examples/models/s2unet.py

@@ -208,19 +208,7 @@ class DownsamplingBlock(nn.Module):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the downsampling block.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Downsampled tensor
-        """
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):

@@ -614,19 +602,7 @@ class SphericalUNet(nn.Module):
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
-        """
-        Forward pass through the complete spherical U-Net model.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
        # encoder:
        features = []
        feat = x
torch_harmonics/examples/models/sfno.py

@@ -153,19 +153,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):
-        """
-        Forward pass through the SFNO block.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        ----------
-        torch.Tensor
-            Output tensor after processing through the block
-        """
        x, residual = self.global_conv(x)

        x = self.norm(x)

@@ -415,19 +403,7 @@ class SphericalFourierNeuralOperator(nn.Module):
        return {"pos_embed.pos_embed"}

    def forward_features(self, x):
-        """
-        Forward pass through the feature extraction layers.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        ----------
-        torch.Tensor
-            Features after processing through the network
-        """
        x = self.pos_drop(x)

        for blk in self.blocks:

@@ -436,19 +412,7 @@ class SphericalFourierNeuralOperator(nn.Module):
        return x

    def forward(self, x):
-        """
-        Forward pass through the complete SFNO model.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        ----------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
        if self.residual_prediction:
            residual = x
torch_harmonics/random_fields.py

@@ -94,24 +94,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)

    def forward(self, N, xi=None):
-        r"""
-        Sample random functions from a spherical GRF.
-
-        Parameters
-        ----------
-        N : int
-            Number of functions to sample.
-        xi : torch.Tensor, default is None
-            Noise is a complex tensor of size (N, nlat, nlat+1).
-            If None, new Gaussian noise is sampled.
-            If xi is provided, N is ignored.
-
-        Output
-        -------
-        u : torch.Tensor
-            N random samples from the GRF returned as a
-            tensor of size (N, nlat, 2*nlat) on a equiangular grid.
-        """
        # Sample Gaussian noise.
        if xi is None:
            xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
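For context, GaussianRandomFieldS2.forward draws Gaussian noise of size (N, nlat, nlat + 1, 2) and returns N samples on an equiangular (nlat, 2*nlat) grid. A hedged usage sketch; it assumes nlat is the first constructor argument and leaves the remaining parameters at their defaults:

import torch
from torch_harmonics.random_fields import GaussianRandomFieldS2

nlat = 64
grf = GaussianRandomFieldS2(nlat)   # assumption: nlat as first constructor argument

# draw 4 random functions on the sphere
u = grf(4)
print(u.shape)                      # expected: torch.Size([4, 64, 128])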