Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
e3103e17
Unverified
Commit
e3103e17
authored
Jan 15, 2024
by
Charchit Sharma
Committed by
GitHub
Jan 15, 2024
Browse files
Make InstructPix2Pix SDXL Training Script torch.compile compatible (#6576)
* changes for pix2pix_sdxl * style fix
parent
b053053a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+14
-4
No files found.
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
View file @
e3103e17
...
@@ -52,6 +52,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instru
...
@@ -52,6 +52,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instru
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
check_min_version
,
deprecate
,
is_wandb_available
,
load_image
from
diffusers.utils
import
check_min_version
,
deprecate
,
is_wandb_available
,
load_image
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -531,6 +532,11 @@ def main():
...
@@ -531,6 +532,11 @@ def main():
else
:
else
:
raise
ValueError
(
"xformers is not available. Make sure it is installed correctly"
)
raise
ValueError
(
"xformers is not available. Make sure it is installed correctly"
)
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# `accelerate` 0.16.0 will have better support for customized saving
# `accelerate` 0.16.0 will have better support for customized saving
if
version
.
parse
(
accelerate
.
__version__
)
>=
version
.
parse
(
"0.16.0"
):
if
version
.
parse
(
accelerate
.
__version__
)
>=
version
.
parse
(
"0.16.0"
):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
...
@@ -1044,8 +1050,12 @@ def main():
...
@@ -1044,8 +1050,12 @@ def main():
added_cond_kwargs
=
{
"text_embeds"
:
add_text_embeds
,
"time_ids"
:
add_time_ids
}
added_cond_kwargs
=
{
"text_embeds"
:
add_text_embeds
,
"time_ids"
:
add_time_ids
}
model_pred
=
unet
(
model_pred
=
unet
(
concatenated_noisy_latents
,
timesteps
,
encoder_hidden_states
,
added_cond_kwargs
=
added_cond_kwargs
concatenated_noisy_latents
,
).
sample
timesteps
,
encoder_hidden_states
,
added_cond_kwargs
=
added_cond_kwargs
,
return_dict
=
False
,
)[
0
]
loss
=
F
.
mse_loss
(
model_pred
.
float
(),
target
.
float
(),
reduction
=
"mean"
)
loss
=
F
.
mse_loss
(
model_pred
.
float
(),
target
.
float
(),
reduction
=
"mean"
)
# Gather the losses across all processes for logging (if we use distributed training).
# Gather the losses across all processes for logging (if we use distributed training).
...
@@ -1115,7 +1125,7 @@ def main():
...
@@ -1115,7 +1125,7 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode.
# The models need unwrapping because for compatibility in distributed training mode.
pipeline
=
StableDiffusionXLInstructPix2PixPipeline
.
from_pretrained
(
pipeline
=
StableDiffusionXLInstructPix2PixPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
unet
=
accelerator
.
unwrap_model
(
unet
),
unet
=
unwrap_model
(
unet
),
text_encoder
=
text_encoder_1
,
text_encoder
=
text_encoder_1
,
text_encoder_2
=
text_encoder_2
,
text_encoder_2
=
text_encoder_2
,
tokenizer
=
tokenizer_1
,
tokenizer
=
tokenizer_1
,
...
@@ -1177,7 +1187,7 @@ def main():
...
@@ -1177,7 +1187,7 @@ def main():
# Create the pipeline using the trained modules and save it.
# Create the pipeline using the trained modules and save it.
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
if
args
.
use_ema
:
if
args
.
use_ema
:
ema_unet
.
copy_to
(
unet
.
parameters
())
ema_unet
.
copy_to
(
unet
.
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment