Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
dce06680
Unverified
Commit
dce06680
authored
Jan 16, 2024
by
Steve Rhoades
Committed by
GitHub
Jan 17, 2024
Browse files
Fixes torch.compile() compatible training (#6589)
resolve conflicts
parent
dd631683
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
...diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+12
-6
No files found.
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
View file @
dce06680
...
@@ -68,6 +68,7 @@ from diffusers.utils import (
...
@@ -68,6 +68,7 @@ from diffusers.utils import (
is_wandb_available
,
is_wandb_available
,
)
)
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -1293,6 +1294,11 @@ def main(args):
...
@@ -1293,6 +1294,11 @@ def main(args):
else
:
else
:
param
.
requires_grad
=
False
param
.
requires_grad
=
False
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
@@ -1303,14 +1309,14 @@ def main(args):
...
@@ -1303,14 +1309,14 @@ def main(args):
text_encoder_two_lora_layers_to_save
=
None
text_encoder_two_lora_layers_to_save
=
None
for
model
in
models
:
for
model
in
models
:
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
)
)
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
if
args
.
train_text_encoder
:
if
args
.
train_text_encoder
:
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
get_peft_model_state_dict
(
model
)
...
@@ -1338,11 +1344,11 @@ def main(args):
...
@@ -1338,11 +1344,11 @@ def main(args):
while
len
(
models
)
>
0
:
while
len
(
models
)
>
0
:
model
=
models
.
pop
()
model
=
models
.
pop
()
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_
=
model
unet_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_
=
model
text_encoder_one_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_
=
model
text_encoder_two_
=
model
else
:
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment