Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
09903774
Unverified
Commit
09903774
authored
Jan 15, 2024
by
Charchit Sharma
Committed by
GitHub
Jan 15, 2024
Browse files
Make T2I Adapter SDXL Training Script torch.compile compatible (#6577)
update for t2i_adapter
parent
d6a70d8b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
examples/t2i_adapter/train_t2i_adapter_sdxl.py
examples/t2i_adapter/train_t2i_adapter_sdxl.py
+11
-4
No files found.
examples/t2i_adapter/train_t2i_adapter_sdxl.py
View file @
09903774
...
@@ -50,6 +50,7 @@ from diffusers import (
...
@@ -50,6 +50,7 @@ from diffusers import (
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
MAX_SEQ_LENGTH
=
77
MAX_SEQ_LENGTH
=
77
...
@@ -926,6 +927,11 @@ def main(args):
...
@@ -926,6 +927,11 @@ def main(args):
else
:
else
:
raise
ValueError
(
"xformers is not available. Make sure it is installed correctly"
)
raise
ValueError
(
"xformers is not available. Make sure it is installed correctly"
)
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
if
args
.
gradient_checkpointing
:
if
args
.
gradient_checkpointing
:
unet
.
enable_gradient_checkpointing
()
unet
.
enable_gradient_checkpointing
()
...
@@ -935,9 +941,9 @@ def main(args):
...
@@ -935,9 +941,9 @@ def main(args):
" doing mixed precision training, copy of the weights should still be float32."
" doing mixed precision training, copy of the weights should still be float32."
)
)
if
accelerator
.
unwrap_model
(
t2iadapter
).
dtype
!=
torch
.
float32
:
if
unwrap_model
(
t2iadapter
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
raise
ValueError
(
f
"Controlnet loaded as datatype
{
accelerator
.
unwrap_model
(
t2iadapter
).
dtype
}
.
{
low_precision_error_string
}
"
f
"Controlnet loaded as datatype
{
unwrap_model
(
t2iadapter
).
dtype
}
.
{
low_precision_error_string
}
"
)
)
# Enable TF32 for faster training on Ampere GPUs,
# Enable TF32 for faster training on Ampere GPUs,
...
@@ -1198,7 +1204,8 @@ def main(args):
...
@@ -1198,7 +1204,8 @@ def main(args):
encoder_hidden_states
=
batch
[
"prompt_ids"
],
encoder_hidden_states
=
batch
[
"prompt_ids"
],
added_cond_kwargs
=
batch
[
"unet_added_conditions"
],
added_cond_kwargs
=
batch
[
"unet_added_conditions"
],
down_block_additional_residuals
=
down_block_additional_residuals
,
down_block_additional_residuals
=
down_block_additional_residuals
,
).
sample
return_dict
=
False
,
)[
0
]
# Denoise the latents
# Denoise the latents
denoised_latents
=
model_pred
*
(
-
sigmas
)
+
noisy_latents
denoised_latents
=
model_pred
*
(
-
sigmas
)
+
noisy_latents
...
@@ -1279,7 +1286,7 @@ def main(args):
...
@@ -1279,7 +1286,7 @@ def main(args):
# Create the pipeline using using the trained modules and save it.
# Create the pipeline using using the trained modules and save it.
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
t2iadapter
=
accelerator
.
unwrap_model
(
t2iadapter
)
t2iadapter
=
unwrap_model
(
t2iadapter
)
t2iadapter
.
save_pretrained
(
args
.
output_dir
)
t2iadapter
.
save_pretrained
(
args
.
output_dir
)
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment