You need to sign in or sign up before continuing.
Unverified Commit e001fede authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix dreambooth loss type with prior_preservation and fp16 (#826)

Fix dreambooth loss type with prior preservation
parent 0a09af2f
...@@ -544,7 +544,7 @@ def main(): ...@@ -544,7 +544,7 @@ def main():
noise, noise_prior = torch.chunk(noise, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss # Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss # Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment