Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
08702fc1
Unverified
Commit
08702fc1
authored
Jan 15, 2024
by
Vinh H. Pham
Committed by
GitHub
Jan 15, 2024
Browse files
Make text-to-image SDXL LoRA Training Script torch.compile compatible (#6556)
make compile compatible
parent
7ce89e97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
17 deletions
+26
-17
examples/text_to_image/train_text_to_image_lora_sdxl.py
examples/text_to_image/train_text_to_image_lora_sdxl.py
+26
-17
No files found.
examples/text_to_image/train_text_to_image_lora_sdxl.py
View file @
08702fc1
...
@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
...
@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
from
diffusers.training_utils
import
cast_training_params
,
compute_snr
from
diffusers.training_utils
import
cast_training_params
,
compute_snr
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
...
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
...
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids
=
text_input_ids_list
[
i
]
text_input_ids
=
text_input_ids_list
[
i
]
prompt_embeds
=
text_encoder
(
prompt_embeds
=
text_encoder
(
text_input_ids
.
to
(
text_encoder
.
device
),
text_input_ids
.
to
(
text_encoder
.
device
),
output_hidden_states
=
True
,
return_dict
=
False
output_hidden_states
=
True
,
)
)
# We are ALWAYS interested only in the pooled output of the final text encoder
# We are ALWAYS interested only in the pooled output of the final text encoder
pooled_prompt_embeds
=
prompt_embeds
[
0
]
pooled_prompt_embeds
=
prompt_embeds
[
0
]
prompt_embeds
=
prompt_embeds
.
hidden_states
[
-
2
]
prompt_embeds
=
prompt_embeds
[
-
1
]
[
-
2
]
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds_list
.
append
(
prompt_embeds
)
prompt_embeds_list
.
append
(
prompt_embeds
)
...
@@ -637,6 +637,11 @@ def main(args):
...
@@ -637,6 +637,11 @@ def main(args):
# only upcast trainable parameters (LoRA) into fp32
# only upcast trainable parameters (LoRA) into fp32
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
def unwrap_model(model):
    """Return the underlying module behind any accelerator/compile wrappers.

    The model is first passed through ``accelerator.unwrap_model``. If the
    result is reported as a compiled module by ``is_compiled_module``, its
    ``_orig_mod`` attribute is returned instead; otherwise the unwrapped
    model is returned as-is.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # Compiled wrappers keep the original module on `_orig_mod`.
        return unwrapped._orig_mod
    return unwrapped
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
@@ -647,13 +652,13 @@ def main(args):
...
@@ -647,13 +652,13 @@ def main(args):
text_encoder_two_lora_layers_to_save
=
None
text_encoder_two_lora_layers_to_save
=
None
for
model
in
models
:
for
model
in
models
:
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
...
@@ -678,11 +683,11 @@ def main(args):
...
@@ -678,11 +683,11 @@ def main(args):
while
len
(
models
)
>
0
:
while
len
(
models
)
>
0
:
model
=
models
.
pop
()
model
=
models
.
pop
()
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_
=
model
unet_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_
=
model
text_encoder_one_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_
=
model
text_encoder_two_
=
model
else
:
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
...
@@ -1031,8 +1036,12 @@ def main(args):
...
@@ -1031,8 +1036,12 @@ def main(args):
)
)
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
noisy_model_input
,
).
sample
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
,
return_dict
=
False
,
)[
0
]
# Get the target for loss depending on the prediction type
# Get the target for loss depending on the prediction type
if
args
.
prediction_type
is
not
None
:
if
args
.
prediction_type
is
not
None
:
...
@@ -1125,9 +1134,9 @@ def main(args):
...
@@ -1125,9 +1134,9 @@ def main(args):
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
vae
=
vae
,
text_encoder
=
accelerator
.
unwrap_model
(
text_encoder_one
),
text_encoder
=
unwrap_model
(
text_encoder_one
),
text_encoder_2
=
accelerator
.
unwrap_model
(
text_encoder_two
),
text_encoder_2
=
unwrap_model
(
text_encoder_two
),
unet
=
accelerator
.
unwrap_model
(
unet
),
unet
=
unwrap_model
(
unet
),
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
...
@@ -1166,12 +1175,12 @@ def main(args):
...
@@ -1166,12 +1175,12 @@ def main(args):
# Save the lora layers
# Save the lora layers
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
text_encoder_one
=
accelerator
.
unwrap_model
(
text_encoder_one
)
text_encoder_one
=
unwrap_model
(
text_encoder_one
)
text_encoder_two
=
accelerator
.
unwrap_model
(
text_encoder_two
)
text_encoder_two
=
unwrap_model
(
text_encoder_two
)
text_encoder_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_one
))
text_encoder_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_one
))
text_encoder_2_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_two
))
text_encoder_2_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_two
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment