Unverified Commit c96bfa5c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (#10031)

compute fourier features in FP32.
parent 6b288ec4
......@@ -437,7 +437,8 @@ class FourierFeatures(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `FourierFeatures` class."""
original_dtype = inputs.dtype
inputs = inputs.to(torch.float32)
num_channels = inputs.shape[1]
num_freqs = (self.stop - self.start) // self.step
......@@ -450,7 +451,7 @@ class FourierFeatures(nn.Module):
# Scale channels by frequency.
h = w * h
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
class MochiEncoder3D(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment