Commit 763d4371 authored by Andrea Paris's avatar Andrea Paris
Browse files

small bf16 fix for w11 loss

parent 6ac50e26
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional from typing import Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -259,6 +260,9 @@ class W11LossS2(SphericalLossBase): ...@@ -259,6 +260,9 @@ class W11LossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh) self.register_buffer("k_theta_mesh", k_theta_mesh)
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prdtype = prd.dtype
with amp.autocast(device_type="cuda", enabled=False):
prd = prd.to(torch.float32)
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment