Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
63b769fc
Commit
63b769fc
authored
Jul 21, 2025
by
Boris Bonev
Browse files
fixing losses
parent
f72a48dd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
92 deletions
+48
-92
torch_harmonics/examples/losses.py
torch_harmonics/examples/losses.py
+48
-92
No files found.
torch_harmonics/examples/losses.py
View file @
63b769fc
...
...
@@ -40,27 +40,6 @@ from torch_harmonics.quadrature import _precompute_latitudes
def
get_quadrature_weights
(
nlat
:
int
,
nlon
:
int
,
grid
:
str
,
tile
:
bool
=
False
,
normalized
:
bool
=
True
)
->
torch
.
Tensor
:
"""
Get quadrature weights for spherical integration.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str
Grid type ("equiangular", "legendre-gauss", "lobatto")
tile : bool, optional
Whether to tile weights across longitude dimension, by default False
normalized : bool, optional
Whether to normalize weights to sum to 1, by default True
Returns
-------
torch.Tensor
Quadrature weights tensor
"""
# area weights
_
,
q
=
_precompute_latitudes
(
nlat
=
nlat
,
grid
=
grid
)
q
=
q
.
reshape
(
-
1
,
1
)
*
2
*
torch
.
pi
/
nlon
...
...
@@ -78,7 +57,7 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,
class
DiceLossS2
(
nn
.
Module
):
"""
Dice loss for spherical segmentation tasks.
Parameters
-----------
nlat : int
...
...
@@ -96,7 +75,7 @@ class DiceLossS2(nn.Module):
mode : str, optional
Aggregation mode ("micro" or "macro"), by default "micro"
"""
def
__init__
(
self
,
nlat
:
int
,
nlon
:
int
,
grid
:
str
=
"equiangular"
,
weight
:
torch
.
Tensor
=
None
,
smooth
:
float
=
0
,
ignore_index
:
int
=
-
100
,
mode
:
str
=
"micro"
):
super
().
__init__
()
...
...
@@ -115,7 +94,6 @@ class DiceLossS2(nn.Module):
self
.
register_buffer
(
"weight"
,
weight
.
unsqueeze
(
0
))
def
forward
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
prd
=
nn
.
functional
.
softmax
(
prd
,
dim
=
1
)
# mask values
...
...
@@ -158,7 +136,7 @@ class DiceLossS2(nn.Module):
class
CrossEntropyLossS2
(
nn
.
Module
):
"""
Cross-entropy loss for spherical classification tasks.
Parameters
-----------
nlat : int
...
...
@@ -204,7 +182,7 @@ class CrossEntropyLossS2(nn.Module):
class
FocalLossS2
(
nn
.
Module
):
"""
Focal loss for spherical classification tasks.
Parameters
-----------
nlat : int
...
...
@@ -275,14 +253,32 @@ class SphericalLossBase(nn.Module, ABC):
@
abstractmethod
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Abstract method that must be implemented by child classes to compute loss terms.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
Returns:
torch.Tensor: Computed loss term before integration
"""
pass
def
_post_integration_hook
(
self
,
loss
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Post-integration hook. Commonly used for the roots in Lp norms"""
return
loss
def
forward
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Common forward pass that handles masking and reduction.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
Returns:
torch.Tensor: Final loss value
"""
loss_term
=
self
.
_compute_loss_term
(
prd
,
tar
)
# Integrate over the sphere for each item in the batch
loss
=
self
.
_integrate_sphere
(
loss_term
,
mask
)
...
...
@@ -293,34 +289,22 @@ class SphericalLossBase(nn.Module, ABC):
class
SquaredL2LossS2
(
SphericalLossBase
):
"""Squared L2 loss for spherical regression tasks."""
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
square
(
prd
-
tar
)
class
L1LossS2
(
SphericalLossBase
):
"""L1 loss for spherical regression tasks."""
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
abs
(
prd
-
tar
)
class
L2LossS2
(
SquaredL2LossS2
):
"""L2 loss for spherical regression tasks."""
def
_post_integration_hook
(
self
,
loss
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
sqrt
(
loss
)
class
W11LossS2
(
SphericalLossBase
):
"""W11 loss for spherical regression tasks."""
def
__init__
(
self
,
nlat
:
int
,
nlon
:
int
,
grid
:
str
=
"equiangular"
):
super
().
__init__
(
nlat
=
nlat
,
nlon
=
nlon
,
grid
=
grid
)
# Set up grid and domain for FFT
l_phi
=
2
*
torch
.
pi
# domain size
...
...
@@ -387,56 +371,31 @@ class NormalLossS2(SphericalLossBase):
self
.
register_buffer
(
"k_theta_mesh"
,
k_theta_mesh
)
def
compute_gradients
(
self
,
x
):
"""
Compute spatial gradients of the input tensor using FFT.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
Returns
-------
tuple
Tuple of (grad_phi, grad_theta) gradients
"""
# Make sure x is reshaped to have a batch dimension if it's missing
if
x
.
dim
()
==
2
:
x
=
x
.
unsqueeze
(
0
)
# Add batch dimension
# Compute gradients using FFT
grad_phi
=
torch
.
fft
.
ifft2
(
1j
*
self
.
k_phi_mesh
*
torch
.
fft
.
fft2
(
x
)).
real
grad_theta
=
torch
.
fft
.
ifft2
(
1j
*
self
.
k_theta_mesh
*
torch
.
fft
.
fft2
(
x
)).
real
return
grad_phi
,
grad_theta
x_prime_fft2_phi_h
=
torch
.
fft
.
ifft2
(
1j
*
self
.
k_phi_mesh
*
torch
.
fft
.
fft2
(
x
)).
real
x_prime_fft2_theta_h
=
torch
.
fft
.
ifft2
(
1j
*
self
.
k_theta_mesh
*
torch
.
fft
.
fft2
(
x
)).
real
return
x_prime_fft2_theta_h
,
x_prime_fft2_phi_h
def
compute_normals
(
self
,
x
):
"""
Compute surface normals from the input tensor.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
Returns
-------
torch.Tensor
Normal vectors with shape (batch, 3, nlat, nlon)
"""
grad_phi
,
grad_theta
=
self
.
compute_gradients
(
x
)
# Construct normal vectors: (-grad_theta, -grad_phi, 1)
normals
=
torch
.
stack
([
-
grad_theta
,
-
grad_phi
,
torch
.
ones_like
(
x
)],
dim
=
1
)
# Normalize
norm
=
torch
.
norm
(
normals
,
dim
=
1
,
keepdim
=
True
)
normals
=
normals
/
(
norm
+
1e-8
)
x
=
x
.
to
(
torch
.
float32
)
# Ensure x has a batch dimension
if
x
.
dim
()
==
2
:
x
=
x
.
unsqueeze
(
0
)
grad_lat
,
grad_lon
=
self
.
compute_gradients
(
x
)
# Create 3D normal vectors
ones
=
torch
.
ones_like
(
x
)
normals
=
torch
.
stack
([
-
grad_lon
,
-
grad_lat
,
ones
],
dim
=
1
)
# Normalize along component dimension
normals
=
F
.
normalize
(
normals
,
p
=
2
,
dim
=
1
)
return
normals
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if
prd
.
dim
()
==
2
:
...
...
@@ -444,18 +403,15 @@ class NormalLossS2(SphericalLossBase):
if
tar
.
dim
()
==
2
:
tar
=
tar
.
unsqueeze
(
0
)
# L1 loss term
l1_loss
=
torch
.
abs
(
prd
-
tar
)
# For 4D tensors (batch, channel, height, width), remove channel if it's 1
if
prd
.
dim
()
==
4
and
prd
.
size
(
1
)
==
1
:
prd
=
prd
.
squeeze
(
1
)
if
tar
.
dim
()
==
4
and
tar
.
size
(
1
)
==
1
:
tar
=
tar
.
squeeze
(
1
)
# Normal consistency loss
prd_normals
=
self
.
compute_normals
(
prd
)
pred_normals
=
self
.
compute_normals
(
prd
)
tar_normals
=
self
.
compute_normals
(
tar
)
# Cosine similarity between normals
cos_sim
=
torch
.
sum
(
prd_normals
*
tar_normals
,
dim
=
1
)
normal_loss
=
1
-
cos_sim
# Combine losses (equal weighting)
combined_loss
=
l1_loss
+
normal_loss
.
unsqueeze
(
1
)
return
combined_loss
# Compute cosine similarity
normal_loss
=
1
-
torch
.
sum
(
pred_normals
*
tar_normals
,
dim
=
1
,
keepdim
=
True
)
return
normal_loss
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment