"vscode:/vscode.git/clone" did not exist on "10a711167aa0d40697d6ff2ab9e4e4ec9e9b3372"
Unverified Commit 8aa4372a authored by Prathik Rao's avatar Prathik Rao Committed by GitHub
Browse files

reorder model wrap + bug fix (#1799)



* reorder model wrap

* bug fix
Co-authored-by: default avatarPrathik Rao <prathikrao@microsoft.com>
parent 60438389
......@@ -287,7 +287,6 @@ def main(args):
"UpBlock2D",
),
)
model = ORTModule(model)
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
......@@ -359,6 +358,8 @@ def main(args):
max_value=args.ema_max_decay,
)
model = ORTModule(model)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
......@@ -424,7 +425,7 @@ def main(args):
with accelerator.accumulate(model):
# Predict the noise residual
model_output = model(noisy_images, timesteps, return_dict=True)[0]
model_output = model(noisy_images, timesteps, return_dict=False)[0]
if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment