Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
262d539a
"examples/vscode:/vscode.git/clone" did not exist on "b6f5ba9a809fcd2e5b2c440f538c1ccc965a9e59"
Unverified
Commit
262d539a
authored
Jun 05, 2023
by
Patrick von Platen
Committed by
GitHub
Jun 05, 2023
Browse files
Correct multi gpu dreambooth (#3673)
Correct multi gpu
parent
0fc2fb71
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
2 deletions
+2
-2
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+1
-1
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+1
-1
No files found.
examples/dreambooth/train_dreambooth.py
View file @
262d539a
...
@@ -1211,7 +1211,7 @@ def main(args):
...
@@ -1211,7 +1211,7 @@ def main(args):
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
)
)
if
unet
.
config
.
in_channels
==
channels
*
2
:
if
accelerator
.
unwrap_model
(
unet
)
.
config
.
in_channels
==
channels
*
2
:
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
if
args
.
class_labels_conditioning
==
"timesteps"
:
if
args
.
class_labels_conditioning
==
"timesteps"
:
...
...
examples/dreambooth/train_dreambooth_lora.py
View file @
262d539a
...
@@ -1156,7 +1156,7 @@ def main(args):
...
@@ -1156,7 +1156,7 @@ def main(args):
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
)
)
if
unet
.
config
.
in_channels
==
channels
*
2
:
if
accelerator
.
unwrap_model
(
unet
)
.
config
.
in_channels
==
channels
*
2
:
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
if
args
.
class_labels_conditioning
==
"timesteps"
:
if
args
.
class_labels_conditioning
==
"timesteps"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment