Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
05faf326
Unverified
Commit
05faf326
authored
Jan 15, 2024
by
gzguevara
Committed by
GitHub
Jan 15, 2024
Browse files
SDXL text-to-image torch compatible (#6550)
* torch compatible * code quality fix * ruff style * ruff format
parent
a080f0d3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
10 deletions
+17
-10
examples/text_to_image/train_text_to_image_sdxl.py
examples/text_to_image/train_text_to_image_sdxl.py
+17
-10
No files found.
examples/text_to_image/train_text_to_image_sdxl.py
View file @
05faf326
...
@@ -44,16 +44,12 @@ from tqdm.auto import tqdm
...
@@ -44,16 +44,12 @@ from tqdm.auto import tqdm
from
transformers
import
AutoTokenizer
,
PretrainedConfig
from
transformers
import
AutoTokenizer
,
PretrainedConfig
import
diffusers
import
diffusers
from
diffusers
import
(
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionXLPipeline
,
UNet2DConditionModel
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionXLPipeline
,
UNet2DConditionModel
,
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
,
compute_snr
from
diffusers.training_utils
import
EMAModel
,
compute_snr
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
...
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
prompt_embeds
=
text_encoder
(
prompt_embeds
=
text_encoder
(
text_input_ids
.
to
(
text_encoder
.
device
),
text_input_ids
.
to
(
text_encoder
.
device
),
output_hidden_states
=
True
,
output_hidden_states
=
True
,
return_dict
=
False
,
)
)
# We are only ALWAYS interested in the pooled output of the final text encoder
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds
=
prompt_embeds
[
0
]
pooled_prompt_embeds
=
prompt_embeds
[
0
]
prompt_embeds
=
prompt_embeds
.
hidden_states
[
-
2
]
prompt_embeds
=
prompt_embeds
[
-
1
]
[
-
2
]
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds_list
.
append
(
prompt_embeds
)
prompt_embeds_list
.
append
(
prompt_embeds
)
...
@@ -955,6 +952,12 @@ def main(args):
...
@@ -955,6 +952,12 @@ def main(args):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
accelerator
.
init_trackers
(
"text2image-fine-tune-sdxl"
,
config
=
vars
(
args
))
accelerator
.
init_trackers
(
"text2image-fine-tune-sdxl"
,
config
=
vars
(
args
))
# Function for unwraping if torch.compile() was used in accelerate.
def
unwrap_model
(
model
):
model
=
accelerator
.
unwrap_model
(
model
)
model
=
model
.
_orig_mod
if
is_compiled_module
(
model
)
else
model
return
model
# Train!
# Train!
total_batch_size
=
args
.
train_batch_size
*
accelerator
.
num_processes
*
args
.
gradient_accumulation_steps
total_batch_size
=
args
.
train_batch_size
*
accelerator
.
num_processes
*
args
.
gradient_accumulation_steps
...
@@ -1054,8 +1057,12 @@ def main(args):
...
@@ -1054,8 +1057,12 @@ def main(args):
pooled_prompt_embeds
=
batch
[
"pooled_prompt_embeds"
].
to
(
accelerator
.
device
)
pooled_prompt_embeds
=
batch
[
"pooled_prompt_embeds"
].
to
(
accelerator
.
device
)
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
noisy_model_input
,
).
sample
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
,
return_dict
=
False
,
)[
0
]
# Get the target for loss depending on the prediction type
# Get the target for loss depending on the prediction type
if
args
.
prediction_type
is
not
None
:
if
args
.
prediction_type
is
not
None
:
...
@@ -1206,7 +1213,7 @@ def main(args):
...
@@ -1206,7 +1213,7 @@ def main(args):
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
if
args
.
use_ema
:
if
args
.
use_ema
:
ema_unet
.
copy_to
(
unet
.
parameters
())
ema_unet
.
copy_to
(
unet
.
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment