"examples/vscode:/vscode.git/clone" did not exist on "b6f5ba9a809fcd2e5b2c440f538c1ccc965a9e59"
Unverified Commit 262d539a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Correct multi gpu dreambooth (#3673)

Correct multi gpu
parent 0fc2fb71
...@@ -1211,7 +1211,7 @@ def main(args): ...@@ -1211,7 +1211,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
) )
if unet.config.in_channels == channels * 2: if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps": if args.class_labels_conditioning == "timesteps":
......
...@@ -1156,7 +1156,7 @@ def main(args): ...@@ -1156,7 +1156,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
) )
if unet.config.in_channels == channels * 2: if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps": if args.class_labels_conditioning == "timesteps":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment