Unverified Commit 0ddc5bf7 authored by Lucca Zenóbio's avatar Lucca Zenóbio Committed by GitHub
Browse files

fix mixed precision training on train_dreambooth_inpaint_lora (#3138)

cast to weight dtype
parent c5933c9c
...@@ -735,7 +735,7 @@ def main(): ...@@ -735,7 +735,7 @@ def main():
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks for mask in masks
] ]
) ).to(dtype=weight_dtype)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment