Commit 763d4371 authored by Andrea Paris's avatar Andrea Paris
Browse files

small bf16 fix for w11 loss

parent 6ac50e26
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional from typing import Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase): ...@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh) self.register_buffer("k_theta_mesh", k_theta_mesh)
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real prdtype = prd.dtype
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real with amp.autocast(device_type="cuda", enabled=False):
prd = prd.to(torch.float32)
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
# Return the element-wise loss term # Return the element-wise loss term
return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h) return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment