Commit 63b769fc authored by Boris Bonev

fixing losses

parent f72a48dd
@@ -40,27 +40,6 @@ from torch_harmonics.quadrature import _precompute_latitudes

 def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
-    """
-    Get quadrature weights for spherical integration.
-
-    Parameters
-    -----------
-    nlat : int
-        Number of latitude points
-    nlon : int
-        Number of longitude points
-    grid : str
-        Grid type ("equiangular", "legendre-gauss", "lobatto")
-    tile : bool, optional
-        Whether to tile weights across longitude dimension, by default False
-    normalized : bool, optional
-        Whether to normalize weights to sum to 1, by default True
-
-    Returns
-    -------
-    torch.Tensor
-        Quadrature weights tensor
-    """
     # area weights
     _, q = _precompute_latitudes(nlat=nlat, grid=grid)
     q = q.reshape(-1, 1) * 2 * torch.pi / nlon
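The construction above is latitude quadrature scaled by the longitudinal cell width 2π/nlon. A minimal standalone check, assuming (as the code above does) that _precompute_latitudes returns torch tensors: integrating the constant function 1 with the unnormalized weights should recover the surface area of the unit sphere, 4π.

import torch
from torch_harmonics.quadrature import _precompute_latitudes

nlat, nlon = 91, 180
_, q = _precompute_latitudes(nlat=nlat, grid="equiangular")
w = q.reshape(-1, 1) * 2 * torch.pi / nlon  # shape (nlat, 1), broadcasts over longitude

# Summing the weights over all nlat * nlon cells approximates the area
# of the unit sphere: expect roughly 4*pi ~ 12.566
total = (w * torch.ones(nlat, nlon)).sum()
print(total)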
@@ -78,7 +57,7 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,

 class DiceLossS2(nn.Module):
     """
     Dice loss for spherical segmentation tasks.

     Parameters
     -----------
     nlat : int
...@@ -96,7 +75,7 @@ class DiceLossS2(nn.Module): ...@@ -96,7 +75,7 @@ class DiceLossS2(nn.Module):
mode : str, optional mode : str, optional
Aggregation mode ("micro" or "macro"), by default "micro" Aggregation mode ("micro" or "macro"), by default "micro"
""" """
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"): def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
super().__init__() super().__init__()
...@@ -115,7 +94,6 @@ class DiceLossS2(nn.Module): ...@@ -115,7 +94,6 @@ class DiceLossS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0)) self.register_buffer("weight", weight.unsqueeze(0))
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd = nn.functional.softmax(prd, dim=1) prd = nn.functional.softmax(prd, dim=1)
# mask values # mask values
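# mask values
For orientation, a sketch of how DiceLossS2 might be called. The constructor arguments come from the signature above; the input shapes (class logits, integer labels) are assumptions inferred from the softmax over dim=1 and the -100 ignore_index, not confirmed by this diff.

import torch

loss_fn = DiceLossS2(nlat=32, nlon=64, grid="equiangular", smooth=1e-6, mode="micro")

logits = torch.randn(4, 10, 32, 64, requires_grad=True)  # (batch, classes, nlat, nlon)
labels = torch.randint(0, 10, (4, 32, 64))               # integer class label per grid point
loss = loss_fn(logits, labels)
loss.backward()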
@@ -158,7 +136,7 @@ class DiceLossS2(nn.Module):

 class CrossEntropyLossS2(nn.Module):
     """
     Cross-entropy loss for spherical classification tasks.

     Parameters
     -----------
     nlat : int

@@ -204,7 +182,7 @@ class CrossEntropyLossS2(nn.Module):

 class FocalLossS2(nn.Module):
     """
     Focal loss for spherical classification tasks.

     Parameters
     -----------
     nlat : int
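The focal term these classes build on (Lin et al., 2017) down-weights well-classified grid points before any spherical quadrature is applied. A self-contained sketch of the pointwise term; the name focal_term and the gamma default are illustrative, not taken from this diff.

import torch
import torch.nn.functional as F

def focal_term(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # log p_t: log-probability of the true class at every grid point
    logp = F.log_softmax(logits, dim=1)                      # (B, C, nlat, nlon)
    logp_t = logp.gather(1, target.unsqueeze(1)).squeeze(1)  # (B, nlat, nlon)
    p_t = logp_t.exp()
    # (1 - p_t)^gamma suppresses easy, confidently classified points
    return -((1.0 - p_t) ** gamma) * logp_t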
@@ -275,14 +253,32 @@ class SphericalLossBase(nn.Module, ABC):

     @abstractmethod
     def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
+        """Abstract method that must be implemented by child classes to compute loss terms.
+
+        Args:
+            prd (torch.Tensor): Prediction tensor
+            tar (torch.Tensor): Target tensor
+
+        Returns:
+            torch.Tensor: Computed loss term before integration
+        """
         pass

     def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
+        """Post-integration hook. Commonly used for the roots in Lp norms"""
         return loss

     def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Common forward pass that handles masking and reduction.
+
+        Args:
+            prd (torch.Tensor): Prediction tensor
+            tar (torch.Tensor): Target tensor
+            mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
+
+        Returns:
+            torch.Tensor: Final loss value
+        """
         loss_term = self._compute_loss_term(prd, tar)

         # Integrate over the sphere for each item in the batch
         loss = self._integrate_sphere(loss_term, mask)
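This is a template method pattern: masking, quadrature, and reduction live in one place, and a subclass only supplies the pointwise term (plus, optionally, a post-integration hook, as L2LossS2 does below). A hypothetical subclass as illustration; the base-class constructor signature (nlat, nlon, grid) is inferred from the super().__init__ call in W11LossS2 later in this diff.

import torch

class HuberLossS2(SphericalLossBase):
    """Hypothetical spherical Huber loss: quadratic near zero, linear in the tails."""

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", delta: float = 1.0):
        super().__init__(nlat=nlat, nlon=nlon, grid=grid)
        self.delta = delta

    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        err = torch.abs(prd - tar)
        quadratic = 0.5 * err**2
        linear = self.delta * (err - 0.5 * self.delta)
        return torch.where(err <= self.delta, quadratic, linear)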
@@ -293,34 +289,22 @@ class SphericalLossBase(nn.Module, ABC):

 class SquaredL2LossS2(SphericalLossBase):
-    """Squared L2 loss for spherical regression tasks."""
     def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
         return torch.square(prd - tar)

 class L1LossS2(SphericalLossBase):
-    """L1 loss for spherical regression tasks."""
     def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
         return torch.abs(prd - tar)

 class L2LossS2(SquaredL2LossS2):
-    """L2 loss for spherical regression tasks."""
     def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
         return torch.sqrt(loss)

 class W11LossS2(SphericalLossBase):
-    """W11 loss for spherical regression tasks."""
     def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
         super().__init__(nlat=nlat, nlon=nlon, grid=grid)

         # Set up grid and domain for FFT
         l_phi = 2 * torch.pi  # domain size
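The sqrt hook in L2LossS2 above is what turns the integrated squared error into a true L2 norm, ‖f‖ = (∫ |f|² dμ)^(1/2). A sketch of the relationship, assuming both losses are constructed on the same grid and reduced identically:

import torch

sq = SquaredL2LossS2(nlat=32, nlon=64, grid="equiangular")
l2 = L2LossS2(nlat=32, nlon=64, grid="equiangular")

prd = torch.randn(2, 3, 32, 64)
tar = torch.randn(2, 3, 32, 64)
print(torch.allclose(l2(prd, tar), torch.sqrt(sq(prd, tar))))  # expected: True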
@@ -387,56 +371,31 @@ class NormalLossS2(SphericalLossBase):
         self.register_buffer("k_theta_mesh", k_theta_mesh)

     def compute_gradients(self, x):
-        """
-        Compute spatial gradients of the input tensor using FFT.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
-
-        Returns
-        -------
-        tuple
-            Tuple of (grad_phi, grad_theta) gradients
-        """
         # Make sure x is reshaped to have a batch dimension if it's missing
         if x.dim() == 2:
             x = x.unsqueeze(0)  # Add batch dimension

-        # Compute gradients using FFT
-        grad_phi = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
-        grad_theta = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
-        return grad_phi, grad_theta
+        x_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
+        x_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
+        return x_prime_fft2_theta_h, x_prime_fft2_phi_h
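Both versions use spectral differentiation: multiplying the FFT by i·k and transforming back differentiates a periodic signal exactly up to the bandlimit. A standalone 1D illustration of the same identity:

import torch

n = 64
x = torch.linspace(0, 2 * torch.pi, n + 1)[:-1]  # periodic grid on [0, 2*pi)
k = torch.fft.fftfreq(n, d=1.0 / n)              # integer wavenumbers

f = torch.sin(3 * x)
df = torch.fft.ifft(1j * k * torch.fft.fft(f)).real

print(torch.allclose(df, 3 * torch.cos(3 * x), atol=1e-4))  # True

Note that only the longitudinal direction is truly periodic on the sphere; applying fft2 across latitude as well, as the method above does, treats theta as periodic and is therefore an approximation near the poles.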
     def compute_normals(self, x):
-        """
-        Compute surface normals from the input tensor.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
-
-        Returns
-        -------
-        torch.Tensor
-            Normal vectors with shape (batch, 3, nlat, nlon)
-        """
-        grad_phi, grad_theta = self.compute_gradients(x)
-        # Construct normal vectors: (-grad_theta, -grad_phi, 1)
-        normals = torch.stack([-grad_theta, -grad_phi, torch.ones_like(x)], dim=1)
-        # Normalize
-        norm = torch.norm(normals, dim=1, keepdim=True)
-        normals = normals / (norm + 1e-8)
+        x = x.to(torch.float32)
+        # Ensure x has a batch dimension
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+
+        grad_lat, grad_lon = self.compute_gradients(x)
+
+        # Create 3D normal vectors
+        ones = torch.ones_like(x)
+        normals = torch.stack([-grad_lon, -grad_lat, ones], dim=1)
+
+        # Normalize along component dimension
+        normals = F.normalize(normals, p=2, dim=1)
+
         return normals
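The stack-and-normalize step treats x as a height field h and uses the standard graph-surface normal n ∝ (-∂h/∂lon, -∂h/∂lat, 1). As a quick isolated sanity check: for a constant field the gradients vanish, so every normal should come out as the unit vector (0, 0, 1).

import torch
import torch.nn.functional as F

grad_lat = torch.zeros(1, 8, 16)
grad_lon = torch.zeros(1, 8, 16)

normals = torch.stack([-grad_lon, -grad_lat, torch.ones(1, 8, 16)], dim=1)
normals = F.normalize(normals, p=2, dim=1)  # unit length along the component dimension

print(normals[0, :, 0, 0])  # tensor([0., 0., 1.])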
     def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
         # Handle dimensions for both prediction and target
         # Ensure we have at least a batch dimension
         if prd.dim() == 2:
@@ -444,18 +403,15 @@ class NormalLossS2(SphericalLossBase):
         if tar.dim() == 2:
             tar = tar.unsqueeze(0)

-        # L1 loss term
-        l1_loss = torch.abs(prd - tar)
+        # For 4D tensors (batch, channel, height, width), remove channel if it's 1
+        if prd.dim() == 4 and prd.size(1) == 1:
+            prd = prd.squeeze(1)
+        if tar.dim() == 4 and tar.size(1) == 1:
+            tar = tar.squeeze(1)

-        # Normal consistency loss
-        prd_normals = self.compute_normals(prd)
+        pred_normals = self.compute_normals(prd)
         tar_normals = self.compute_normals(tar)

-        # Cosine similarity between normals
-        cos_sim = torch.sum(prd_normals * tar_normals, dim=1)
-        normal_loss = 1 - cos_sim
-        # Combine losses (equal weighting)
-        combined_loss = l1_loss + normal_loss.unsqueeze(1)
-        return combined_loss
+        # Compute cosine similarity
+        normal_loss = 1 - torch.sum(pred_normals * tar_normals, dim=1, keepdim=True)
+        return normal_loss
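After this change the loss term is pure normal consistency; the previous L1 component is dropped. Because compute_normals returns unit vectors, 1 - Σ n_prd · n_tar is 1 - cos(angle) between the two surface normals: it ranges over [0, 2] and vanishes exactly when the normals align. A minimal check:

import torch

n = torch.tensor([0.0, 0.0, 1.0]).reshape(1, 3, 1, 1)  # a single unit normal

print(1 - torch.sum(n * n, dim=1, keepdim=True).squeeze())   # 0.0 (aligned)
print(1 - torch.sum(n * -n, dim=1, keepdim=True).squeeze())  # 2.0 (opposite)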