OpenDAS / diffusers · Commit 7ce89e97 (unverified)
Authored Jan 15, 2024 by Vinh H. Pham, committed by GitHub on Jan 15, 2024
Parent: 05faf326

Make text-to-image SD LoRA Training Script torch.compile compatible (#6555)

make compile compatible
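In practice, "torch.compile compatible" means the script keeps working when the UNet has been wrapped by torch.compile (for example through Accelerate's dynamo integration). A minimal sketch of the wrapper behaviour the patch has to account for; this snippet is not taken from the script itself:

import torch
import torch.nn as nn

# torch.compile returns an OptimizedModule wrapper; the original module stays
# reachable on `_orig_mod`, which is what the patch unwraps before exporting
# LoRA weights.
model = nn.Linear(8, 8)
compiled = torch.compile(model)
print(type(compiled).__name__)      # OptimizedModule
print(compiled._orig_mod is model)  # True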
Showing 1 changed file with 11 additions and 5 deletions.

examples/text_to_image/train_text_to_image_lora.py  (+11, -5)
@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
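The added import is the helper the rest of the patch uses to detect a compiled module. A small illustration of how it behaves on a plain nn.Module versus a torch.compile wrapper, assuming only that diffusers and a recent PyTorch are installed (not part of the diff):

import torch
import torch.nn as nn
from diffusers.utils.torch_utils import is_compiled_module

# is_compiled_module reports whether a module is a torch.compile wrapper.
plain = nn.Linear(8, 8)
compiled = torch.compile(plain)
print(is_compiled_module(plain))     # False
print(is_compiled_module(compiled))  # True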
@@ -596,6 +597,11 @@ def main():
             ]
         )
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
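This helper is the core of the change: accelerator.unwrap_model strips Accelerate's own wrappers, and the extra `_orig_mod` step then removes the torch.compile wrapper, which (at least in the Accelerate versions this script targets) survives the first unwrap. Leaving the wrapper in place matters for saving, because a compiled module's state-dict keys carry an `_orig_mod.` prefix. A standalone illustration of that prefix, separate from the script:

import torch
import torch.nn as nn

# Keys of a compiled module's state dict gain an "_orig_mod." prefix, which
# would not match what the LoRA export code further down expects.
model = nn.Linear(4, 4)
compiled = torch.compile(model)
print(list(model.state_dict()))     # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']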
@@ -729,7 +735,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
 
                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
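With return_dict=False the text encoder returns a plain tuple instead of a ModelOutput dataclass; index 0 is the same last_hidden_state either way, and the tuple form tends to avoid graph breaks under torch.compile. A small equivalence check, with the model id chosen purely for illustration:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Both call styles yield the same hidden states; only the container differs.
model_id = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

input_ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
with torch.no_grad():
    from_dataclass = text_encoder(input_ids)[0]
    from_tuple = text_encoder(input_ids, return_dict=False)[0]
print(torch.allclose(from_dataclass, from_tuple))  # True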
@@ -744,7 +750,7 @@ def main():
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
 
                 if args.snr_gamma is None:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
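The same switch for the UNet: `.sample` on the dataclass output and index 0 on the tuple refer to the same prediction tensor, and the tuple form is what lets the forward pass run cleanly under a compiled UNet. A rough smoke test of the patched call pattern, again with an illustrative model id (the first call after torch.compile is slow because it triggers compilation):

import torch
from diffusers import UNet2DConditionModel

# The patched call style also works when the UNet is wrapped by torch.compile.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet = torch.compile(unet)

noisy_latents = torch.randn(1, 4, 64, 64)
timesteps = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)
with torch.no_grad():
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
print(model_pred.shape)  # torch.Size([1, 4, 64, 64])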
@@ -809,7 +815,7 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = accelerator.unwrap_model(unet)
+                        unwrapped_unet = unwrap_model(unet)
                         unet_lora_state_dict = convert_state_dict_to_diffusers(
                             get_peft_model_state_dict(unwrapped_unet)
                         )
@@ -837,7 +843,7 @@ def main():
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -878,7 +884,7 @@ def main():
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
 
-        unwrapped_unet = accelerator.unwrap_model(unet)
+        unwrapped_unet = unwrap_model(unet)
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
 
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
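Both the mid-training checkpoint path and the final export now route through unwrap_model before reading the LoRA weights, and the validation pipeline uses the same helper for its unet. A hypothetical helper condensing that export flow, using only calls that already appear in the script:

from peft.utils import get_peft_model_state_dict

from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from diffusers.utils.torch_utils import is_compiled_module


def export_unet_lora(accelerator, unet, output_dir):
    # Strip the Accelerate wrapper, then the torch.compile wrapper, pull the
    # LoRA-only weights, convert key names, and write them out.
    unwrapped_unet = accelerator.unwrap_model(unet)
    if is_compiled_module(unwrapped_unet):
        unwrapped_unet = unwrapped_unet._orig_mod
    unet_lora_state_dict = convert_state_dict_to_diffusers(
        get_peft_model_state_dict(unwrapped_unet)
    )
    StableDiffusionPipeline.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=unet_lora_state_dict,
        safe_serialization=True,
    )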