Commit 63b769fc authored by Boris Bonev

fixing losses

parent f72a48dd
@@ -40,27 +40,6 @@ from torch_harmonics.quadrature import _precompute_latitudes
def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
"""
Get quadrature weights for spherical integration.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str
Grid type ("equiangular", "legendre-gauss", "lobatto")
tile : bool, optional
Whether to tile weights across longitude dimension, by default False
normalized : bool, optional
Whether to normalize weights to sum to 1, by default True
Returns
-------
torch.Tensor
Quadrature weights tensor
"""
# area weights
_, q = _precompute_latitudes(nlat=nlat, grid=grid)
q = q.reshape(-1, 1) * 2 * torch.pi / nlon
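The hunk above only shows the start of the weight computation. As a rough, self-contained sketch (an assumption, not the torch_harmonics implementation), normalized area weights for an equiangular grid and their use in a weighted-sum integration might look like this:

```python
import torch

# Minimal sketch of equiangular quadrature weights (illustrative only):
# latitude weights ~ sin(theta), tiled across longitude and normalized to
# sum to 1 so that a plain weighted sum integrates a field over the sphere.
def equiangular_quad_weights(nlat: int, nlon: int) -> torch.Tensor:
    theta = torch.linspace(0, torch.pi, nlat)  # colatitudes
    w = torch.sin(theta).reshape(-1, 1).expand(nlat, nlon).clone()
    return w / w.sum()

w = equiangular_quad_weights(64, 128)
field = torch.ones(64, 128)
print((field * w).sum())  # ~1.0 for a constant field with normalized weights
```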
@@ -78,7 +57,7 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,
class DiceLossS2(nn.Module):
"""
Dice loss for spherical segmentation tasks.
Parameters
-----------
nlat : int
@@ -96,7 +75,7 @@ class DiceLossS2(nn.Module):
mode : str, optional
Aggregation mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
@@ -115,7 +94,6 @@ class DiceLossS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd = nn.functional.softmax(prd, dim=1)
# mask values
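For context, a hypothetical usage sketch of DiceLossS2 with the constructor signature shown above; the import path, the (batch, classes, nlat, nlon) logit layout, and the integer-index target format are assumptions, not confirmed by this hunk:

```python
import torch
from torch_harmonics.losses import DiceLossS2  # assumed import path

nlat, nlon, nclasses = 32, 64, 4
criterion = DiceLossS2(nlat=nlat, nlon=nlon, grid="equiangular", smooth=1e-6, mode="micro")

logits = torch.randn(2, nclasses, nlat, nlon, requires_grad=True)  # raw class scores
target = torch.randint(0, nclasses, (2, nlat, nlon))               # assumed integer labels
loss = criterion(logits, target)                                   # assuming a scalar is returned
loss.backward()
```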
@@ -158,7 +136,7 @@ class DiceLossS2(nn.Module):
class CrossEntropyLossS2(nn.Module):
"""
Cross-entropy loss for spherical classification tasks.
Parameters
-----------
nlat : int
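The rest of the docstring is collapsed here. Conceptually, a spherical cross-entropy replaces the uniform pixel mean with a quadrature-weighted integral; a self-contained sketch of that idea (not the class in this diff):

```python
import torch
import torch.nn.functional as F

# Area-weighted cross-entropy sketch: per-pixel NLL integrated with quadrature
# weights w (summing to 1) instead of a uniform mean over the grid.
def weighted_cross_entropy(logits, target, w):
    nll = F.cross_entropy(logits, target, reduction="none")  # (B, nlat, nlon)
    return (nll * w).sum(dim=(-2, -1)).mean()                # integrate, then batch mean
```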
@@ -204,7 +182,7 @@ class CrossEntropyLossS2(nn.Module):
class FocalLossS2(nn.Module):
"""
Focal loss for spherical classification tasks.
Parameters
-----------
nlat : int
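Again only the head of the docstring is visible. For reference, the standard focal term is -(1 - p_t)^gamma * log p_t; a per-pixel sketch with an assumed gamma, not the FocalLossS2 implementation:

```python
import torch
import torch.nn.functional as F

# Per-pixel focal term for integer targets; gamma and any class weighting used
# by FocalLossS2 are not visible in this hunk, so treat this as illustrative only.
def focal_term(logits, target, gamma: float = 2.0):
    logp = F.log_softmax(logits, dim=1)                      # (B, C, nlat, nlon)
    logp_t = logp.gather(1, target.unsqueeze(1)).squeeze(1)  # log p_t per pixel
    return -((1.0 - logp_t.exp()) ** gamma) * logp_t
```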
@@ -275,14 +253,32 @@ class SphericalLossBase(nn.Module, ABC):
@abstractmethod
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""Abstract method that must be implemented by child classes to compute loss terms.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
Returns:
torch.Tensor: Computed loss term before integration
"""
pass
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
"""Post-integration hook. Commonly used for the roots in Lp norms"""
return loss
def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Common forward pass that handles masking and reduction.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
Returns:
torch.Tensor: Final loss value
"""
loss_term = self._compute_loss_term(prd, tar)
# Integrate over the sphere for each item in the batch
loss = self._integrate_sphere(loss_term, mask)
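With this template, a new loss only has to supply the pointwise term; masking, spherical integration, and the optional post-integration hook come from the base class. A hypothetical subclass (name and choice of term are illustrative, assuming SphericalLossBase from this file is in scope):

```python
import torch
import torch.nn.functional as F

# Illustrative subclass of the template above, not part of this commit.
class SmoothL1LossS2(SphericalLossBase):
    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        return F.smooth_l1_loss(prd, tar, reduction="none")
```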
@@ -293,34 +289,22 @@ class SphericalLossBase(nn.Module, ABC):
class SquaredL2LossS2(SphericalLossBase):
"""Squared L2 loss for spherical regression tasks."""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
return torch.square(prd - tar)
class L1LossS2(SphericalLossBase):
"""L1 loss for spherical regression tasks."""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
return torch.abs(prd - tar)
class L2LossS2(SquaredL2LossS2):
"""L2 loss for spherical regression tasks."""
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
return torch.sqrt(loss)
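Reading the three classes together: the squared-L2 and L1 losses integrate the pointwise term over the sphere, and L2LossS2 reuses the squared term and takes the root after integration. A quick sketch of that pattern with explicit quadrature weights (not the library classes):

```python
import torch

# Explicit version of the Lp pattern above: integrate the pointwise term with
# quadrature weights w (summing to 1), then take the root post-integration.
def l2_loss_s2(prd, tar, w):
    return torch.sqrt((torch.square(prd - tar) * w).sum(dim=(-2, -1)))
```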
class W11LossS2(SphericalLossBase):
"""W11 loss for spherical regression tasks."""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
super().__init__(nlat=nlat, nlon=nlon, grid=grid)
# Set up grid and domain for FFT
l_phi = 2 * torch.pi # domain size
@@ -387,56 +371,31 @@ class NormalLossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh)
def compute_gradients(self, x):
"""
Compute spatial gradients of the input tensor using FFT.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
Returns
-------
tuple
Tuple of (grad_phi, grad_theta) gradients
"""
# Make sure x is reshaped to have a batch dimension if it's missing
if x.dim() == 2:
x = x.unsqueeze(0) # Add batch dimension
# Compute gradients using FFT
-grad_phi = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
-grad_theta = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
-return grad_phi, grad_theta
+x_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
+x_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
+return x_prime_fft2_theta_h, x_prime_fft2_phi_h
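A self-contained check of the spectral-differentiation identity used in compute_gradients, shown in 1D for brevity: multiplying the FFT by i*k and inverse-transforming recovers the derivative of a periodic signal.

```python
import torch

# 1D spectral derivative: d/dx sin(3x) = 3 cos(3x) on a periodic domain of length 2*pi.
n = 128
l = 2 * torch.pi
x = torch.arange(n) * l / n
k = 2 * torch.pi * torch.fft.fftfreq(n, d=l / n)  # angular wavenumbers
f = torch.sin(3 * x)
df = torch.fft.ifft(1j * k * torch.fft.fft(f)).real
print(torch.allclose(df, 3 * torch.cos(3 * x), atol=1e-4))  # True
```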
def compute_normals(self, x):
"""
Compute surface normals from the input tensor.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
Returns
-------
torch.Tensor
Normal vectors with shape (batch, 3, nlat, nlon)
"""
-grad_phi, grad_theta = self.compute_gradients(x)
-# Construct normal vectors: (-grad_theta, -grad_phi, 1)
-normals = torch.stack([-grad_theta, -grad_phi, torch.ones_like(x)], dim=1)
-# Normalize
-norm = torch.norm(normals, dim=1, keepdim=True)
-normals = normals / (norm + 1e-8)
+x = x.to(torch.float32)
+# Ensure x has a batch dimension
+if x.dim() == 2:
+    x = x.unsqueeze(0)
+grad_lat, grad_lon = self.compute_gradients(x)
+# Create 3D normal vectors
+ones = torch.ones_like(x)
+normals = torch.stack([-grad_lon, -grad_lat, ones], dim=1)
+# Normalize along component dimension
+normals = F.normalize(normals, p=2, dim=1)
return normals
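A standalone sanity check of the normal construction above (not calling the class): with zero gradients, i.e. a constant height field, the unit normals reduce to (0, 0, 1) everywhere.

```python
import torch
import torch.nn.functional as F

# Zero gradients -> normals (-0, -0, 1), unchanged by unit normalization.
grad_lon = torch.zeros(1, 8, 16)
grad_lat = torch.zeros(1, 8, 16)
normals = F.normalize(torch.stack([-grad_lon, -grad_lat, torch.ones(1, 8, 16)], dim=1), dim=1)
print(normals[:, 2].min().item())  # 1.0
```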
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if prd.dim() == 2:
@@ -444,18 +403,15 @@ class NormalLossS2(SphericalLossBase):
if tar.dim() == 2:
tar = tar.unsqueeze(0)
-# L1 loss term
-l1_loss = torch.abs(prd - tar)
+# For 4D tensors (batch, channel, height, width), remove channel if it's 1
+if prd.dim() == 4 and prd.size(1) == 1:
+    prd = prd.squeeze(1)
+if tar.dim() == 4 and tar.size(1) == 1:
+    tar = tar.squeeze(1)
# Normal consistency loss
-prd_normals = self.compute_normals(prd)
+pred_normals = self.compute_normals(prd)
tar_normals = self.compute_normals(tar)
-# Cosine similarity between normals
-cos_sim = torch.sum(prd_normals * tar_normals, dim=1)
-normal_loss = 1 - cos_sim
-# Combine losses (equal weighting)
-combined_loss = l1_loss + normal_loss.unsqueeze(1)
-return combined_loss
+# Compute cosine similarity
+normal_loss = 1 - torch.sum(pred_normals * tar_normals, dim=1, keepdim=True)
+return normal_loss
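Since both normal fields are unit length, the dot product between predicted and target normals in the loss term above is just a cosine similarity; a standalone equivalence sketch (not the class):

```python
import torch
import torch.nn.functional as F

# 1 - <n_pred, n_tar> equals 1 - cosine similarity when the normals are unit vectors.
pred_normals = F.normalize(torch.randn(2, 3, 8, 16), dim=1)
tar_normals = F.normalize(torch.randn(2, 3, 8, 16), dim=1)
a = 1 - torch.sum(pred_normals * tar_normals, dim=1, keepdim=True)
b = 1 - F.cosine_similarity(pred_normals, tar_normals, dim=1).unsqueeze(1)
print(torch.allclose(a, b, atol=1e-6))  # True
```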