Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
33d2b5b0
Unverified
Commit
33d2b5b0
authored
Jan 12, 2024
by
gzguevara
Committed by
GitHub
Jan 12, 2024
Browse files
SD text-to-image torch compile compatible (#6519)
* added unwrapper * fiz typo
parent
f486d34b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
3 deletions
+10
-3
examples/text_to_image/train_text_to_image.py
examples/text_to_image/train_text_to_image.py
+10
-3
No files found.
examples/text_to_image/train_text_to_image.py
View file @
33d2b5b0
...
...
@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
from
diffusers.training_utils
import
EMAModel
,
compute_snr
from
diffusers.utils
import
check_min_version
,
deprecate
,
is_wandb_available
,
make_image_grid
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
if
is_wandb_available
():
...
...
@@ -833,6 +834,12 @@ def main():
tracker_config
.
pop
(
"validation_prompts"
)
accelerator
.
init_trackers
(
args
.
tracker_project_name
,
tracker_config
)
# Function for unwrapping if model was compiled with `torch.compile`.
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# Train!
total_batch_size
=
args
.
train_batch_size
*
accelerator
.
num_processes
*
args
.
gradient_accumulation_steps
...
...
@@ -912,7 +919,7 @@ def main():
noisy_latents
=
noise_scheduler
.
add_noise
(
latents
,
noise
,
timesteps
)
# Get the text embedding for conditioning
encoder_hidden_states
=
text_encoder
(
batch
[
"input_ids"
])[
0
]
encoder_hidden_states
=
text_encoder
(
batch
[
"input_ids"
]
,
return_dict
=
False
)[
0
]
# Get the target for loss depending on the prediction type
if
args
.
prediction_type
is
not
None
:
...
...
@@ -927,7 +934,7 @@ def main():
raise
ValueError
(
f
"Unknown prediction type
{
noise_scheduler
.
config
.
prediction_type
}
"
)
# Predict the noise residual and compute loss
model_pred
=
unet
(
noisy_latents
,
timesteps
,
encoder_hidden_states
).
sample
model_pred
=
unet
(
noisy_latents
,
timesteps
,
encoder_hidden_states
,
return_dict
=
False
)[
0
]
if
args
.
snr_gamma
is
None
:
loss
=
F
.
mse_loss
(
model_pred
.
float
(),
target
.
float
(),
reduction
=
"mean"
)
...
...
@@ -1023,7 +1030,7 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
if
args
.
use_ema
:
ema_unet
.
copy_to
(
unet
.
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment