Unverified Commit e4fe9413 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[examples] update loss computation (#1861)

update loss computation
parent ac373846
...@@ -732,7 +732,7 @@ def main(args): ...@@ -732,7 +732,7 @@ def main(args):
target, target_prior = torch.chunk(target, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss # Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss # Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
......
...@@ -634,7 +634,8 @@ def main(): ...@@ -634,7 +634,8 @@ def main():
else: else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss) accelerator.backward(loss)
optimizer.step() optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment