"tests/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "e1f363463773c4fba379c11b1a00a5a20a53debd"
Unverified Commit e607a582 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples] Fix type-casting issue in the ControlNet training script (#2994)

* fix: norm group test for UNet3D.

* fix: type-casting issue in controlnet training.
parent ea39cd7e
...@@ -972,8 +972,10 @@ def main(args): ...@@ -972,8 +972,10 @@ def main(args):
noisy_latents, noisy_latents,
timesteps, timesteps,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=down_block_res_samples, down_block_additional_residuals=[
mid_block_additional_residual=mid_block_res_sample, sample.to(dtype=weight_dtype) for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample ).sample
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment