OpenDAS / diffusers · Commits

Commit 09903774 (unverified)
Authored Jan 15, 2024 by Charchit Sharma; committed by GitHub on Jan 15, 2024
Parent: d6a70d8b

Make T2I Adapter SDXL Training Script torch.compile compatible (#6577)

update for t2i_adapter
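For context (this paragraph and sketch are not part of the commit): torch.compile does not modify a module in place. It returns an OptimizedModule wrapper that keeps the original module reachable through the _orig_mod attribute, so a training script needs one extra unwrapping step, on top of accelerate's own wrapping, before it can inspect or save the underlying model. A minimal sketch, assuming PyTorch 2.x:

    import torch
    import torch.nn as nn

    # torch.compile wraps an nn.Module in an OptimizedModule; the original
    # module stays reachable through the _orig_mod attribute.
    model = nn.Linear(4, 4)
    compiled = torch.compile(model)

    print(type(compiled).__name__)      # OptimizedModule
    print(compiled._orig_mod is model)  # True

That extra layer is what the changes below account for.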
Showing 1 changed file with 11 additions and 4 deletions.

examples/t2i_adapter/train_t2i_adapter_sdxl.py (+11, -4)

@@ -50,6 +50,7 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 MAX_SEQ_LENGTH = 77
@@ -926,6 +927,11 @@ def main(args):
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
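The new unwrap_model helper strips both layers in order: first accelerate's distributed wrapper, then the torch.compile one. is_compiled_module is essentially an isinstance check against the wrapper type; a paraphrased sketch of that check (not the exact library source):

    import torch

    def is_compiled_module_sketch(module) -> bool:
        # A module produced by torch.compile is an instance of
        # torch._dynamo.eval_frame.OptimizedModule.
        if not hasattr(torch, "_dynamo"):
            return False
        return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)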
@@ -935,9 +941,9 @@ def main(args):
             " doing mixed precision training, copy of the weights should still be float32."
         )
 
-    if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
+    if unwrap_model(t2iadapter).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
         )
 
     # Enable TF32 for faster training on Ampere GPUs,
@@ -1198,7 +1204,8 @@ def main(args):
                     encoder_hidden_states=batch["prompt_ids"],
                     added_cond_kwargs=batch["unet_added_conditions"],
                     down_block_additional_residuals=down_block_additional_residuals,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # Denoise the latents
                 denoised_latents = model_pred * (-sigmas) + noisy_latents
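The .sample → [0] change is the other compile-related fix: with return_dict=False the UNet returns a plain tuple instead of an output dataclass, and plain tuples trace more cleanly under torch.compile than custom output objects, which can trigger graph breaks. A toy stand-in illustrating the two call styles (ToyUNet and ToyOutput are illustrative only, not from diffusers; assumes PyTorch 2.x):

    import torch

    class ToyOutput:
        def __init__(self, sample):
            self.sample = sample

    class ToyUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, x, return_dict=True):
            sample = self.proj(x)
            if not return_dict:
                # A plain tuple, as requested by the updated training script.
                return (sample,)
            return ToyOutput(sample)

    unet = torch.compile(ToyUNet())
    x = torch.randn(2, 8)
    model_pred = unet(x, return_dict=False)[0]  # mirrors the change above
    print(model_pred.shape)                     # torch.Size([2, 8])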
@@ -1279,7 +1286,7 @@ def main(args):
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        t2iadapter = accelerator.unwrap_model(t2iadapter)
+        t2iadapter = unwrap_model(t2iadapter)
         t2iadapter.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
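Routing the save through the helper matters because a compiled wrapper's state_dict keys carry an "_orig_mod." prefix, so a checkpoint written from the wrapper would not load back into a fresh, uncompiled adapter. A minimal sketch, assuming PyTorch 2.x:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    compiled = torch.compile(model)

    # The wrapper's state_dict keys gain an "_orig_mod." prefix, so save
    # from the unwrapped module to keep checkpoints loadable.
    print(list(model.state_dict()))     # ['weight', 'bias']
    print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']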