Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
dff5ff35
Unverified
Commit
dff5ff35
authored
Aug 07, 2023
by
Patrick von Platen
Committed by
GitHub
Aug 07, 2023
Browse files
[SDXL LoRA] fix batch size lora (#4509)
fix batch size lora
parent
b2456717
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
examples/dreambooth/train_dreambooth_lora_sdxl.py
examples/dreambooth/train_dreambooth_lora_sdxl.py
+4
-4
No files found.
examples/dreambooth/train_dreambooth_lora_sdxl.py
View file @
dff5ff35
...
@@ -1103,11 +1103,11 @@ def main(args):
...
@@ -1103,11 +1103,11 @@ def main(args):
"time_ids"
:
add_time_ids
.
repeat
(
elems_to_repeat
,
1
),
"time_ids"
:
add_time_ids
.
repeat
(
elems_to_repeat
,
1
),
"text_embeds"
:
unet_add_text_embeds
.
repeat
(
elems_to_repeat
,
1
),
"text_embeds"
:
unet_add_text_embeds
.
repeat
(
elems_to_repeat
,
1
),
}
}
prompt_embeds
=
prompt_embeds
.
repeat
(
elems_to_repeat
,
1
,
1
)
prompt_embeds
_input
=
prompt_embeds
.
repeat
(
elems_to_repeat
,
1
,
1
)
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
noisy_model_input
,
timesteps
,
timesteps
,
prompt_embeds
,
prompt_embeds
_input
,
added_cond_kwargs
=
unet_added_conditions
,
added_cond_kwargs
=
unet_added_conditions
,
).
sample
).
sample
else
:
else
:
...
@@ -1119,9 +1119,9 @@ def main(args):
...
@@ -1119,9 +1119,9 @@ def main(args):
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
)
)
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
.
repeat
(
elems_to_repeat
,
1
)})
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
.
repeat
(
elems_to_repeat
,
1
)})
prompt_embeds
=
prompt_embeds
.
repeat
(
elems_to_repeat
,
1
,
1
)
prompt_embeds
_input
=
prompt_embeds
.
repeat
(
elems_to_repeat
,
1
,
1
)
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
noisy_model_input
,
timesteps
,
prompt_embeds
_input
,
added_cond_kwargs
=
unet_added_conditions
).
sample
).
sample
# Get the target for loss depending on the prediction type
# Get the target for loss depending on the prediction type
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment