Unverified Commit e001fede authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix dreambooth loss type with prior_preservation and fp16 (#826)

Fix dreambooth loss type with prior preservation
parent 0a09af2f
...@@ -544,7 +544,7 @@ def main(): ...@@ -544,7 +544,7 @@ def main():
noise, noise_prior = torch.chunk(noise, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss # Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss # Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment