"vscode:/vscode.git/clone" did not exist on "67f5fcf7099aa0857230995277c264d66d2fc0ab"
Unverified Commit 60c384bc authored by pink-red's avatar pink-red Committed by GitHub
Browse files

Fix fine-tuning compatibility with deepspeed (#816)

parent 008b608f
......@@ -568,7 +568,7 @@ def main():
# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="mean")
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment