Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
08702fc1
Unverified
Commit
08702fc1
authored
Jan 15, 2024
by
Vinh H. Pham
Committed by
GitHub
Jan 15, 2024
Browse files
Make text-to-image SDXL LoRA Training Script torch.compile compatible (#6556)
make compile compatible
parent
7ce89e97
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
17 deletions
+26
-17
examples/text_to_image/train_text_to_image_lora_sdxl.py
examples/text_to_image/train_text_to_image_lora_sdxl.py
+26
-17
No files found.
examples/text_to_image/train_text_to_image_lora_sdxl.py
View file @
08702fc1
...
@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
...
@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
from
diffusers.training_utils
import
cast_training_params
,
compute_snr
from
diffusers.training_utils
import
cast_training_params
,
compute_snr
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
...
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
...
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
text_input_ids
=
text_input_ids_list
[
i
]
text_input_ids
=
text_input_ids_list
[
i
]
prompt_embeds
=
text_encoder
(
prompt_embeds
=
text_encoder
(
text_input_ids
.
to
(
text_encoder
.
device
),
text_input_ids
.
to
(
text_encoder
.
device
),
output_hidden_states
=
True
,
return_dict
=
False
output_hidden_states
=
True
,
)
)
# We are only ALWAYS interested in the pooled output of the final text encoder
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds
=
prompt_embeds
[
0
]
pooled_prompt_embeds
=
prompt_embeds
[
0
]
prompt_embeds
=
prompt_embeds
.
hidden_states
[
-
2
]
prompt_embeds
=
prompt_embeds
[
-
1
]
[
-
2
]
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds_list
.
append
(
prompt_embeds
)
prompt_embeds_list
.
append
(
prompt_embeds
)
...
@@ -637,6 +637,11 @@ def main(args):
...
@@ -637,6 +637,11 @@ def main(args):
# only upcast trainable parameters (LoRA) into fp32
# only upcast trainable parameters (LoRA) into fp32
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
cast_training_params
(
models
,
dtype
=
torch
.
float32
)
def unwrap_model(model):
    """Return the underlying module for `model`, with wrappers removed.

    The model is first passed through `accelerator.unwrap_model`. If the
    result is reported as a compiled module by `is_compiled_module`, its
    `_orig_mod` attribute is returned instead (presumably the original
    module wrapped by `torch.compile` -- confirm against the PyTorch docs).

    Args:
        model: The possibly wrapped and/or compiled model.

    Returns:
        The unwrapped model, or its `_orig_mod` when it is a compiled module.
    """
    unwrapped = accelerator.unwrap_model(model)
    # Compiled modules keep the original module on `_orig_mod`; hand that
    # back so isinstance checks and state-dict extraction see the real model.
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
@@ -647,13 +652,13 @@ def main(args):
...
@@ -647,13 +652,13 @@ def main(args):
text_encoder_two_lora_layers_to_save
=
None
text_encoder_two_lora_layers_to_save
=
None
for
model
in
models
:
for
model
in
models
:
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
...
@@ -678,11 +683,11 @@ def main(args):
...
@@ -678,11 +683,11 @@ def main(args):
while
len
(
models
)
>
0
:
while
len
(
models
)
>
0
:
model
=
models
.
pop
()
model
=
models
.
pop
()
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_
=
model
unet_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_
=
model
text_encoder_one_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_
=
model
text_encoder_two_
=
model
else
:
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
...
@@ -1031,8 +1036,12 @@ def main(args):
...
@@ -1031,8 +1036,12 @@ def main(args):
)
)
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
noisy_model_input
,
).
sample
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
,
return_dict
=
False
,
)[
0
]
# Get the target for loss depending on the prediction type
# Get the target for loss depending on the prediction type
if
args
.
prediction_type
is
not
None
:
if
args
.
prediction_type
is
not
None
:
...
@@ -1125,9 +1134,9 @@ def main(args):
...
@@ -1125,9 +1134,9 @@ def main(args):
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
vae
=
vae
,
vae
=
vae
,
text_encoder
=
accelerator
.
unwrap_model
(
text_encoder_one
),
text_encoder
=
unwrap_model
(
text_encoder_one
),
text_encoder_2
=
accelerator
.
unwrap_model
(
text_encoder_two
),
text_encoder_2
=
unwrap_model
(
text_encoder_two
),
unet
=
accelerator
.
unwrap_model
(
unet
),
unet
=
unwrap_model
(
unet
),
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
...
@@ -1166,12 +1175,12 @@ def main(args):
...
@@ -1166,12 +1175,12 @@ def main(args):
# Save the lora layers
# Save the lora layers
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
text_encoder_one
=
accelerator
.
unwrap_model
(
text_encoder_one
)
text_encoder_one
=
unwrap_model
(
text_encoder_one
)
text_encoder_two
=
accelerator
.
unwrap_model
(
text_encoder_two
)
text_encoder_two
=
unwrap_model
(
text_encoder_two
)
text_encoder_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_one
))
text_encoder_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_one
))
text_encoder_2_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_two
))
text_encoder_2_lora_layers
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder_two
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment