Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
dce06680
Unverified
Commit
dce06680
authored
Jan 16, 2024
by
Steve Rhoades
Committed by
GitHub
Jan 17, 2024
Browse files
Fixes torch.compile() compatible training (#6589)
resolve conflicts
parent
dd631683
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
...diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+12
-6
No files found.
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
View file @
dce06680
...
...
@@ -68,6 +68,7 @@ from diffusers.utils import (
is_wandb_available
,
)
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
...
@@ -1293,6 +1294,11 @@ def main(args):
else
:
param
.
requires_grad
=
False
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def
save_model_hook
(
models
,
weights
,
output_dir
):
if
accelerator
.
is_main_process
:
...
...
@@ -1303,14 +1309,14 @@ def main(args):
text_encoder_two_lora_layers_to_save
=
None
for
model
in
models
:
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
))
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
if
args
.
train_text_encoder
:
text_encoder_one_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
)
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
if
args
.
train_text_encoder
:
text_encoder_two_lora_layers_to_save
=
convert_state_dict_to_diffusers
(
get_peft_model_state_dict
(
model
)
...
...
@@ -1338,11 +1344,11 @@ def main(args):
while
len
(
models
)
>
0
:
model
=
models
.
pop
()
if
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
unet
))):
if
isinstance
(
model
,
type
(
unwrap_model
(
unet
))):
unet_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_one
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_one
))):
text_encoder_one_
=
model
elif
isinstance
(
model
,
type
(
accelerator
.
unwrap_model
(
text_encoder_two
))):
elif
isinstance
(
model
,
type
(
unwrap_model
(
text_encoder_two
))):
text_encoder_two_
=
model
else
:
raise
ValueError
(
f
"unexpected save model:
{
model
.
__class__
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment