renzhc / diffusers_dcu · Commits

Unverified commit e44b205e, authored Jan 12, 2024 by Charchit Sharma, committed by GitHub on Jan 12, 2024
Make ControlNet SDXL Training Script torch.compile compatible (#6526)
* make torch.compile compatible
* fix quality
Parent: 60cb4432
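For context on the change: `torch.compile` does not modify a module in place; it returns a `torch._dynamo.OptimizedModule` wrapper that keeps the original module on its `_orig_mod` attribute, which is what the new `unwrap_model` helper in this diff reaches for. A minimal sketch of that behavior, independent of this commit (the `nn.Linear` model is just a stand-in):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# torch.compile returns a wrapper object; the original module is still
# reachable through `_orig_mod`, which is what is_compiled_module() checks for.
print(type(compiled).__name__)      # OptimizedModule
print(compiled._orig_mod is model)  # True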
Showing 1 changed file with 11 additions and 4 deletions (+11, -4)
examples/controlnet/train_controlnet_sdxl.py @ e44b205e
@@ -52,6 +52,7 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 if is_wandb_available():
@@ -847,6 +848,11 @@ def main(args):
         logger.info("Initializing controlnet weights from unet")
         controlnet = ControlNetModel.from_unet(unet)
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
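One reason the script must unwrap before saving: a compiled module registers the original model as a submodule, so its state_dict keys carry an `_orig_mod.` prefix, and a checkpoint written from the wrapper would not load back into a plain ControlNetModel. A small illustrative sketch, not part of the commit (`nn.Linear` stands in for the ControlNet):

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
compiled = torch.compile(net)

# The wrapper's state_dict keys are prefixed with `_orig_mod.`, which is why
# save_pretrained() should be called on the unwrapped model instead.
print(list(net.state_dict().keys()))       # ['weight', 'bias']
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']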
@@ -908,9 +914,9 @@ def main(args):
             " doing mixed precision training, copy of the weights should still be float32."
         )
 
-    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+    if unwrap_model(controlnet).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
         )
 
     # Enable TF32 for faster training on Ampere GPUs,
@@ -1158,7 +1164,8 @@ def main(args):
                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                     ],
                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
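The `return_dict=False` change above makes the UNet return a plain tuple instead of a `UNet2DConditionOutput` dataclass, presumably to avoid graph breaks under torch.compile; indexing `[0]` retrieves the same tensor that `.sample` would. A self-contained sketch with a tiny randomly initialized UNet (the config values are arbitrary, chosen only so the model builds quickly, and are not from this commit):

import torch
from diffusers import UNet2DConditionModel

# Tiny, untrained UNet purely for illustration.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 4, 32)

# Dataclass-style output vs. plain tuple: same prediction either way,
# but the tuple form is friendlier to torch.compile.
pred_a = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
pred_b = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]
print(torch.allclose(pred_a, pred_b))  # True (no dropout in this default config)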
@@ -1223,7 +1230,7 @@ def main(args):
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)
 
         if args.push_to_hub: