OpenDAS / torch-harmonics · Commits

Commit 901e8635
Authored Jul 16, 2025 by Andrea Paris; committed Jul 21, 2025 by Boris Bonev
Parent: 6373534a

removed docstrings from forward passes
Showing 13 changed files with 32 additions and 568 deletions (+32, -568).
examples/baseline_models/segformer.py                 +1   -26
torch_harmonics/convolution.py                        +1   -13
torch_harmonics/distributed/distributed_resample.py   +0   -13
torch_harmonics/distributed/primitives.py             +1   -57
torch_harmonics/examples/losses.py                    +1   -58
torch_harmonics/examples/metrics.py                   +3   -45
torch_harmonics/examples/models/_layers.py            +9   -117
torch_harmonics/examples/models/lsno.py               +3   -39
torch_harmonics/examples/models/s2segformer.py        +4   -78
torch_harmonics/examples/models/s2transformer.py      +3   -39
torch_harmonics/examples/models/s2unet.py             +2   -26
torch_harmonics/examples/models/sfno.py               +3   -39
torch_harmonics/random_fields.py                      +1   -18
examples/baseline_models/segformer.py

@@ -103,19 +103,6 @@ class OverlapPatchMerging(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x):
-        """
-        Forward pass through the OverlapPatchMerging layer.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after patch merging
-        """
         x = self.conv(x)
         # permute

@@ -204,19 +191,7 @@ class MixFFN(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the MixFFN module.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after processing
-        """
         residual = x
         # norm
torch_harmonics/convolution.py

@@ -693,19 +693,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         return psi

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass
-        Parameters
-        -----------
-        x: torch.Tensor
-            Input tensor
-        Returns
-        -------
-        out: torch.Tensor
-            Output tensor
-        """
         # extract shape
         B, C, H, W = x.shape
         x = x.reshape(B, self.groups, self.groupsize, H, W)
torch_harmonics/distributed/distributed_resample.py

@@ -248,19 +248,6 @@ class DistributedResampleS2(nn.Module):
         return x

     def forward(self, x: torch.Tensor):
-        """
-        Forward pass for distributed resampling.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Resampled tensor with shape (batch, channels, nlat_out, nlon_out)
-        """
         if self.skip_resampling:
             return x
torch_harmonics/distributed/primitives.py

@@ -150,23 +150,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x, dims, dim1_split_sizes):
-        r"""
-        Forward pass for distributed azimuthal transpose.
-        Parameters
-        ----------
-        x: torch.Tensor
-            The tensor to transpose
-        dims: List[int]
-            The dimensions to transpose
-        dim1_split_sizes: List[int]
-            The split sizes for the second dimension
-        Returns
-        -------
-        x: torch.Tensor
-            The transposed tensor
-        """
         # WAR for a potential contig check torch bug for channels last contig tensors
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
         x = torch.cat(xlist, dim=dims[1])

@@ -205,23 +189,7 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x, dim, dim1_split_sizes):
-        r"""
-        Forward pass for distributed polar transpose.
-        Parameters
-        ----------
-        x: torch.Tensor
-            The tensor to transpose
-        dim: List[int]
-            The dimensions to transpose
-        dim1_split_sizes: List[int]
-            The split sizes for the second dimension
-        Returns
-        -------
-        x: torch.Tensor
-            The transposed tensor
-        """
         # WAR for a potential contig check torch bug for channels last contig tensors
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
         x = torch.cat(xlist, dim=dim[1])

@@ -363,19 +331,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, input_):
-        r"""
-        Forward pass for copying to polar region.
-        Parameters
-        ----------
-        input_: torch.Tensor
-            The tensor to copy
-        Returns
-        -------
-        input_: torch.Tensor
-            The tensor to copy
-        """
         return input_

     @staticmethod

@@ -409,19 +365,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, input_):
-        r"""
-        Forward pass for copying to azimuth region.
-        Parameters
-        ----------
-        input_: torch.Tensor
-            The tensor to copy
-        Returns
-        -------
-        input_: torch.Tensor
-            The tensor to copy
-        """
         return input_

     @staticmethod
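The removed docstrings describe the distributed transposes as a split / exchange / concatenate operation over a communicator group. A minimal sketch of that pattern follows, assuming an already-initialized torch.distributed process group (NCCL or MPI backend), identical tensor shapes on every rank, and evenly divisible dimensions; the function name is illustrative and the real _transpose additionally handles uneven splits via dim1_split_sizes.

import torch
import torch.distributed as dist

def toy_distributed_transpose(x: torch.Tensor, dim0: int, dim1: int, group=None) -> torch.Tensor:
    # assumes dist.init_process_group(...) has been called and x.shape[dim0]
    # is divisible by the group size, with the same shape on every rank
    world_size = dist.get_world_size(group=group)
    # split the locally complete dim0 into one chunk per rank ...
    send = [t.contiguous() for t in torch.chunk(x, world_size, dim=dim0)]
    recv = [torch.empty_like(t) for t in send]
    # ... exchange the chunks so each rank collects its share of dim1 ...
    dist.all_to_all(recv, send, group=group)
    # ... and concatenate along dim1: dim0 is now split, dim1 is now complete
    return torch.cat(recv, dim=dim1)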
torch_harmonics/examples/losses.py

@@ -115,21 +115,7 @@ class DiceLossS2(nn.Module):
         self.register_buffer("weight", weight.unsqueeze(0))

     def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the Dice loss.
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Dice loss value
-        """
         prd = nn.functional.softmax(prd, dim=1)
         # mask values

@@ -205,21 +191,6 @@ class CrossEntropyLossS2(nn.Module):
         self.register_buffer("quad_weights", q)

     def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the cross-entropy loss.
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Cross-entropy loss value
-        """
         # compute log softmax
         logits = nn.functional.log_softmax(prd, dim=1)

@@ -266,25 +237,6 @@ class FocalLossS2(nn.Module):
         self.register_buffer("quad_weights", q)

     def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
-        """
-        Forward pass of the focal loss.
-        Parameters
-        -----------
-        prd : torch.Tensor
-            Prediction tensor with shape (batch, classes, nlat, nlon)
-        tar : torch.Tensor
-            Target tensor with shape (batch, nlat, nlon)
-        alpha : float, optional
-            Alpha parameter for focal loss, by default 0.25
-        gamma : float, optional
-            Gamma parameter for focal loss, by default 2
-        Returns
-        -------
-        torch.Tensor
-            Focal loss value
-        """
         # compute logits
         logits = nn.functional.log_softmax(prd, dim=1)

@@ -339,16 +291,7 @@ class SphericalLossBase(nn.Module, ABC):
         return loss

     def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Common forward pass that handles masking and reduction.
-        Args:
-            prd (torch.Tensor): Prediction tensor
-            tar (torch.Tensor): Target tensor
-            mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
-        Returns:
-            torch.Tensor: Final loss value
-        """
         loss_term = self._compute_loss_term(prd, tar)
         # Integrate over the sphere for each item in the batch
         loss = self._integrate_sphere(loss_term, mask)
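All of these losses follow the recipe visible in the context lines: a pointwise loss term followed by quadrature-weighted integration over the sphere. A self-contained sketch of that recipe for cross-entropy on an equiangular grid is shown below; it uses an illustrative sin(theta) area weight in place of the library's registered quad_weights buffer, and the function name is not part of the library.

import math
import torch
import torch.nn.functional as F

def toy_spherical_cross_entropy(prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
    # prd: (batch, classes, nlat, nlon) logits, tar: (batch, nlat, nlon) integer labels
    B, C, nlat, nlon = prd.shape
    logits = F.log_softmax(prd, dim=1)
    nll = -logits.gather(1, tar.unsqueeze(1)).squeeze(1)   # (batch, nlat, nlon)
    # illustrative area weights: proportional to sin(colatitude), normalized so the
    # weighted sum over the grid behaves like a mean over the sphere
    theta = torch.linspace(0.0, math.pi, nlat, device=prd.device)
    w = torch.sin(theta)
    w = w / (w.sum() * nlon)
    return (nll * w.view(1, nlat, 1)).sum(dim=(-2, -1)).mean()

# usage: toy_spherical_cross_entropy(torch.randn(2, 4, 32, 64), torch.randint(0, 4, (2, 32, 64)))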
torch_harmonics/examples/metrics.py

@@ -169,21 +169,7 @@ class BaseMetricS2(nn.Module):
         self.register_buffer("weight", weight.unsqueeze(0))

     def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Compute base statistics (TP, FP, FN, TN) for the given predictions and ground truth.
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-        Returns
-        -------
-        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-            Tuple containing (tp, fp, fn, tn) statistics
-        """
         # convert logits to class predictions
         pred_class = _predict_classes(pred)

@@ -240,21 +226,7 @@ class IntersectionOverUnionS2(BaseMetricS2):
         super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

     def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
-        """
-        Compute IoU score for the given predictions and ground truth.
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-        Returns
-        -------
-        torch.Tensor
-            IoU score
-        """
         tp, fp, fn, tn = self._forward(pred, truth)
         # compute score

@@ -300,21 +272,7 @@ class AccuracyS2(BaseMetricS2):
         super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

     def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
-        """
-        Compute accuracy score for the given predictions and ground truth.
-        Parameters
-        -----------
-        pred : torch.Tensor
-            Predicted logits
-        truth : torch.Tensor
-            Ground truth labels
-        Returns
-        -------
-        torch.Tensor
-            Accuracy score
-        """
         tp, fp, fn, tn = self._forward(pred, truth)
         # compute score
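The statistics described in the removed _forward docstring are ordinary confusion counts, from which IoU and accuracy follow directly. A hedged single-class sketch is given below; names are illustrative only, and the library version is class-wise, respects ignore_index, and applies the quadrature weight buffer.

import torch

def toy_confusion_counts(pred_logits: torch.Tensor, truth: torch.Tensor, cls: int = 1):
    # pred_logits: (batch, classes, nlat, nlon) logits, truth: (batch, nlat, nlon) labels
    pred_class = pred_logits.argmax(dim=1)
    p = pred_class == cls
    t = truth == cls
    tp = (p & t).sum()
    fp = (p & ~t).sum()
    fn = (~p & t).sum()
    tn = (~p & ~t).sum()
    return tp, fp, fn, tn

# IoU = tp / (tp + fp + fn); accuracy = (tp + tn) / (tp + fp + fn + tn)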
torch_harmonics/examples/models/_layers.py

@@ -191,19 +191,7 @@ class DropPath(nn.Module):
         self.drop_prob = drop_prob

     def forward(self, x):
-        """
-        Forward pass with drop path regularization.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor with potential path dropping
-        """
         return drop_path(x, self.drop_prob, self.training)

@@ -238,19 +226,7 @@ class PatchEmbed(nn.Module):
         self.proj.bias.is_shared_mp = ["spatial"]

     def forward(self, x):
-        """
-        Forward pass of patch embedding.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, channels, height, width)
-        Returns
-        -------
-        torch.Tensor
-            Patch embeddings of shape (batch_size, embed_dim, num_patches)
-        """
         # gather input
         B, C, H, W = x.shape
         assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

@@ -319,35 +295,11 @@ class MLP(nn.Module):
     @torch.jit.ignore
     def checkpoint_forward(self, x):
-        """
-        Forward pass with gradient checkpointing.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor
-        """
         return checkpoint(self.fwd, x)

     def forward(self, x):
-        """
-        Forward pass of the MLP.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor
-        """
         if self.checkpointing:
             return self.checkpoint_forward(x)
         else:

@@ -382,19 +334,7 @@ class RealFFT2(nn.Module):
         self.mmax = mmax or self.nlon // 2 + 1

     def forward(self, x):
-        """
-        Forward pass: compute real FFT2D.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            FFT coefficients
-        """
         y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
         y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
         return y

@@ -428,19 +368,7 @@ class InverseRealFFT2(nn.Module):
         self.mmax = mmax or self.nlon // 2 + 1

     def forward(self, x):
-        """
-        Forward pass: compute inverse real FFT2D.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input FFT coefficients
-        Returns
-        -------
-        torch.Tensor
-            Reconstructed spatial signal
-        """
         return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")

@@ -476,19 +404,7 @@ class LayerNorm(nn.Module):
         self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=1e-6, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

     def forward(self, x):
-        """
-        Forward pass with channel dimension handling.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with channel dimension at -3
-        Returns
-        -------
-        torch.Tensor
-            Normalized tensor with same shape as input
-        """
         return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)

@@ -556,19 +472,7 @@ class SpectralConvS2(nn.Module):
         self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

     def forward(self, x):
-        """
-        Forward pass of spectral convolution.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        tuple
-            Tuple containing (output, residual) tensors
-        """
         dtype = x.dtype
         x = x.float()
         residual = x

@@ -614,19 +518,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
         self.num_chans = num_chans

     def forward(self, x: torch.Tensor):
-        """
-        Forward pass: add position embeddings to input.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Input tensor with position embeddings added
-        """
         return x + self.position_embeddings
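The RealFFT2 and InverseRealFFT2 forwards shown above keep only the lowest lmax meridional and mmax zonal modes. The standalone sketch below reproduces that truncation and an explicit zero-padded inverse; the parameter values are illustrative, and the library's InverseRealFFT2 simply calls irfft2 as shown in the diff.

import math
import torch

nlat, nlon = 64, 128
lmax, mmax = 32, 33          # illustrative truncation limits

x = torch.randn(2, 3, nlat, nlon)
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
# keep the first ceil(lmax/2) and last floor(lmax/2) frequency rows and the
# first mmax columns, mirroring RealFFT2.forward above
y = torch.cat((y[..., : math.ceil(lmax / 2), :mmax],
               y[..., -math.floor(lmax / 2):, :mmax]), dim=-2)

# explicit inverse: scatter the retained modes back into a full spectrum and invert
z = torch.zeros(2, 3, nlat, nlon // 2 + 1, dtype=y.dtype)
z[..., : math.ceil(lmax / 2), :mmax] = y[..., : math.ceil(lmax / 2), :]
z[..., -math.floor(lmax / 2):, :mmax] = y[..., -math.floor(lmax / 2):, :]
x_lowpass = torch.fft.irfft2(z, dim=(-2, -1), s=(nlat, nlon), norm="ortho")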
torch_harmonics/examples/models/lsno.py

@@ -110,19 +110,7 @@ class DiscreteContinuousEncoder(nn.Module):
         )

     def forward(self, x):
-        """
-        Forward pass of the discrete-continuous encoder.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Encoded tensor with reduced spatial resolution
-        """
         dtype = x.dtype
         with amp.autocast(device_type="cuda", enabled=False):

@@ -205,19 +193,7 @@ class DiscreteContinuousDecoder(nn.Module):
         )

     def forward(self, x):
-        """
-        Forward pass of the discrete-continuous decoder.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Decoded tensor with restored spatial resolution
-        """
         dtype = x.dtype
         with amp.autocast(device_type="cuda", enabled=False):

@@ -628,19 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module):
         return x

     def forward(self, x):
-        """
-        Forward pass through the complete LSNO model.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
         if self.residual_prediction:
             residual = x
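A recurring detail in these encoder/decoder forwards (and in s2transformer.py below) is the mixed-precision guard visible in the context lines: remember the incoming dtype, run the precision-sensitive spherical transform in float32 with autocast disabled, then cast back. A minimal sketch of that pattern, with a placeholder transform:

import torch
from torch import amp

def fp32_spectral_region(x: torch.Tensor, spectral_op) -> torch.Tensor:
    # spectral_op is a placeholder for e.g. a DISCO convolution or SHT call
    dtype = x.dtype
    with amp.autocast(device_type="cuda", enabled=False):
        x = spectral_op(x.float())
    return x.to(dtype)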
torch_harmonics/examples/models/s2segformer.py

@@ -130,19 +130,7 @@ class OverlapPatchMerging(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x):
-        """
-        Forward pass of the overlap patch merging module.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Merged patches with layer normalization
-        """
         dtype = x.dtype
         with amp.autocast(device_type="cuda", enabled=False):

@@ -259,19 +247,6 @@ class MixFFN(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the Mix FFN module.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after Mix FFN processing
-        """
         residual = x
         # norm

@@ -382,19 +357,7 @@ class AttentionWrapper(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the attention wrapper.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after attention processing
-        """
         residual = x
         if self.norm is not None:
             x = x.permute(0, 2, 3, 1)

@@ -548,19 +511,6 @@ class TransformerBlock(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the transformer block.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after transformer block processing
-        """
         x = self.fwd(x)
         # apply norm

@@ -664,19 +614,7 @@ class Upsampling(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the upsampling module.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Upsampled tensor
-        """
         x = self.upsample(self.mlp(x))
         return x

@@ -871,19 +809,7 @@ class SphericalSegformer(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x):
-        """
-        Forward pass through the complete spherical segformer model.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
         # encoder:
         features = []
         feat = x
torch_harmonics/examples/models/s2transformer.py

@@ -113,19 +113,7 @@ class DiscreteContinuousEncoder(nn.Module):
         )

     def forward(self, x):
-        """
-        Forward pass of the discrete-continuous encoder.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Encoded tensor with reduced spatial resolution
-        """
         dtype = x.dtype
         with amp.autocast(device_type="cuda", enabled=False):

@@ -209,19 +197,7 @@ class DiscreteContinuousDecoder(nn.Module):
         )

     def forward(self, x):
-        """
-        Forward pass of the discrete-continuous decoder.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, channels, nlat, nlon)
-        Returns
-        -------
-        torch.Tensor
-            Decoded tensor with restored spatial resolution
-        """
         dtype = x.dtype
         with amp.autocast(device_type="cuda", enabled=False):

@@ -593,19 +569,7 @@ class SphericalTransformer(nn.Module):
         return x

     def forward(self, x):
-        """
-        Forward pass through the complete spherical transformer model.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
         if self.residual_prediction:
             residual = x
torch_harmonics/examples/models/s2unet.py

@@ -208,19 +208,7 @@ class DownsamplingBlock(nn.Module):
             nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the downsampling block.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        -------
-        torch.Tensor
-            Downsampled tensor
-        """
         # skip connection
         residual = x
         if hasattr(self, "transform_skip"):

@@ -614,19 +602,7 @@ class SphericalUNet(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x):
-        """
-        Forward pass through the complete spherical U-Net model.
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
         # encoder:
         features = []
         feat = x
torch_harmonics/examples/models/sfno.py

@@ -153,19 +153,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
             raise ValueError(f"Unknown skip connection type {outer_skip}")

     def forward(self, x):
-        """
-        Forward pass through the SFNO block.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        ----------
-        torch.Tensor
-            Output tensor after processing through the block
-        """
         x, residual = self.global_conv(x)
         x = self.norm(x)

@@ -415,19 +403,7 @@ class SphericalFourierNeuralOperator(nn.Module):
         return {"pos_embed.pos_embed"}

     def forward_features(self, x):
-        """
-        Forward pass through the feature extraction layers.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-        Returns
-        ----------
-        torch.Tensor
-            Features after processing through the network
-        """
         x = self.pos_drop(x)
         for blk in self.blocks:

@@ -436,19 +412,7 @@ class SphericalFourierNeuralOperator(nn.Module):
         return x

     def forward(self, x):
-        """
-        Forward pass through the complete SFNO model.
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-        Returns
-        ----------
-        torch.Tensor
-            Output tensor of shape (batch_size, out_chans, height, width)
-        """
         if self.residual_prediction:
             residual = x
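The context lines outline the block structure: the spectral convolution returns both its output and a residual branch, followed by normalization and an MLP with skip connections, and the top-level forward optionally adds the input back when residual_prediction is set. A toy sketch of that skeleton follows, with ordinary layers standing in for SpectralConvS2 and the library's norm/MLP; it is not the actual SFNO implementation.

import torch
import torch.nn as nn

class ToySFNOBlock(nn.Module):
    def __init__(self, chans: int):
        super().__init__()
        self.global_conv = nn.Conv2d(chans, chans, 1)   # stand-in for SpectralConvS2
        self.norm = nn.GroupNorm(1, chans)
        self.mlp = nn.Sequential(nn.Conv2d(chans, 4 * chans, 1), nn.GELU(), nn.Conv2d(4 * chans, chans, 1))

    def forward(self, x):
        residual = x                       # the real block takes the residual from global_conv
        x = self.norm(self.global_conv(x))
        return self.mlp(x) + residual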
torch_harmonics/random_fields.py

@@ -94,24 +94,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
         self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)

     def forward(self, N, xi=None):
-        r"""
-        Sample random functions from a spherical GRF.
-        Parameters
-        ----------
-        N : int
-            Number of functions to sample.
-        xi : torch.Tensor, default is None
-            Noise is a complex tensor of size (N, nlat, nlat+1).
-            If None, new Gaussian noise is sampled.
-            If xi is provided, N is ignored.
-        Output
-        -------
-        u : torch.Tensor
-            N random samples from the GRF returned as a
-            tensor of size (N, nlat, 2*nlat) on a equiangular grid.
-        """
         #Sample Gaussian noise.
         if xi is None:
             xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
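Based on the removed docstring, usage of the sampler looks roughly like the following; the constructor argument and the noise-shape handling are assumptions about the API rather than guarantees.

import torch
from torch_harmonics.random_fields import GaussianRandomFieldS2

grf = GaussianRandomFieldS2(nlat=64)            # assumes nlat is the leading constructor argument
u = grf(10)                                     # 10 samples of shape (10, 64, 128) per the docstring
xi = grf.gaussian_noise.sample(torch.Size((10, 64, 65, 2))).squeeze()
u_fixed_noise = grf(10, xi=xi)                  # with explicit noise, N is ignored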