Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
60cb4432
Unverified
Commit
60cb4432
authored
Jan 12, 2024
by
Vinh H. Pham
Committed by
GitHub
Jan 12, 2024
Browse files
Make Dreambooth SD LoRA Training Script torch.compile compatible (#6534)
support compile
parent
1dd0ac94
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
11 deletions
+22
-11
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+22
-11
No files found.
examples/dreambooth/train_dreambooth_lora.py
View file @
60cb4432
...
@@ -56,6 +56,7 @@ from diffusers.loaders import LoraLoaderMixin
...
@@ -56,6 +56,7 @@ from diffusers.loaders import LoraLoaderMixin
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
convert_state_dict_to_diffusers
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
...
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
prompt_embeds
=
text_encoder
(
prompt_embeds
=
text_encoder
(
text_input_ids
,
text_input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
return_dict
=
False
,
)
)
prompt_embeds
=
prompt_embeds
[
0
]
prompt_embeds
=
prompt_embeds
[
0
]
...
@@ -843,6 +845,11 @@ def main(args):
...
@@ -843,6 +845,11 @@ def main(args):
)
)
text_encoder
.
add_adapter
(
text_lora_config
)
text_encoder
.
add_adapter
(
text_lora_config
)
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
@@ -852,9 +859,9 @@ def main(args):
...
@@ -852,9 +859,9 @@ def main(args):
text_encoder_lora_layers_to_save
=
None
text_encoder_lora_layers_to_save
=
None
for
model
in
models
:
for
model
in
models
:
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder
))):
text_encoder_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
...
@@ -877,9 +884,9 @@ def main(args):
...
@@ -877,9 +884,9 @@ def main(args):
while
len
(
models
)
>
0
:
while
len
(
models
)
>
0
:
model
=
models
.
pop
()
model
=
models
.
pop
()
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_
=
model
unet_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder
))):
text_encoder_
=
model
text_encoder_
=
model
else
:
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
...
@@ -1118,7 +1125,7 @@ def main(args):
...
@@ -1118,7 +1125,7 @@ def main(args):
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
text_encoder_use_attention_mask
=
args
.
text_encoder_use_attention_mask
,
)
)
if
accelerator
.
unwrap_model
(
unet
).
config
.
in_channels
==
channels
*
2
:
if
unwrap_model
(
unet
).
config
.
in_channels
==
channels
*
2
:
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
noisy_model_input
=
torch
.
cat
([
noisy_model_input
,
noisy_model_input
],
dim
=
1
)
if
args
.
class_labels_conditioning
==
"timesteps"
:
if
args
.
class_labels_conditioning
==
"timesteps"
:
...
@@ -1128,8 +1135,12 @@ def main(args):
...
@@ -1128,8 +1135,12 @@ def main(args):
# Predict the noise residual
# Predict the noise residual
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
encoder_hidden_states
,
class_labels
=
class_labels
noisy_model_input
,
).
sample
timesteps
,
encoder_hidden_states
,
class_labels
=
class_labels
,
return_dict
=
False
,
)[
0
]
# if model predicts variance, throw away the prediction. we will only train on the
# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine tuned
# simplified training objective. This means that all schedulers using the fine tuned
...
@@ -1215,8 +1226,8 @@ def main(args):
...
@@ -1215,8 +1226,8 @@ def main(args):
# create pipeline
# create pipeline
pipeline
=
DiffusionPipeline
.
from_pretrained
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
unet
=
accelerator
.
unwrap_model
(
unet
),
unet
=
unwrap_model
(
unet
),
text_encoder
=
None
if
args
.
pre_compute_text_embeddings
else
accelerator
.
unwrap_model
(
text_encoder
),
text_encoder
=
None
if
args
.
pre_compute_text_embeddings
else
unwrap_model
(
text_encoder
),
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
...
@@ -1284,13 +1295,13 @@ def main(args):
...
@@ -1284,13 +1295,13 @@ def main(args):
# Save the lora layers
# Save the lora layers
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
unet
=
unet
.
to
(
torch
.
float32
)
unet
=
unet
.
to
(
torch
.
float32
)
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
unet_lora_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
unet
))
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
text_encoder
=
accelerator
.
unwrap_model
(
text_encoder
)
text_encoder
=
unwrap_model
(
text_encoder
)
text_encoder_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder
))
text_encoder_state_dict
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
text_encoder
))
else
:
else
:
text_encoder_state_dict
=
None
text_encoder_state_dict
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment