Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
05faf326
Unverified
Commit
05faf326
authored
Jan 15, 2024
by
gzguevara
Committed by
GitHub
Jan 15, 2024
Browse files
SDXL text-to-image torch compatible (#6550)
* torch compatible * code quality fix * ruff style * ruff format
parent
a080f0d3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
10 deletions
+17
-10
examples/text_to_image/train_text_to_image_sdxl.py
examples/text_to_image/train_text_to_image_sdxl.py
+17
-10
No files found.
examples/text_to_image/train_text_to_image_sdxl.py
View file @
05faf326
...
@@ -44,16 +44,12 @@ from tqdm.auto import tqdm
...
@@ -44,16 +44,12 @@ from tqdm.auto import tqdm
from
transformers
import
AutoTokenizer
,
PretrainedConfig
from
transformers
import
AutoTokenizer
,
PretrainedConfig
import
diffusers
import
diffusers
from
diffusers
import
(
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionXLPipeline
,
UNet2DConditionModel
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionXLPipeline
,
UNet2DConditionModel
,
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
,
compute_snr
from
diffusers.training_utils
import
EMAModel
,
compute_snr
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
...
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
...
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
prompt_embeds
=
text_encoder
(
prompt_embeds
=
text_encoder
(
text_input_ids
.
to
(
text_encoder
.
device
),
text_input_ids
.
to
(
text_encoder
.
device
),
output_hidden_states
=
True
,
output_hidden_states
=
True
,
return_dict
=
False
,
)
)
# We are ALWAYS interested only in the pooled output of the final text encoder
# We are ALWAYS interested only in the pooled output of the final text encoder
pooled_prompt_embeds
=
prompt_embeds
[
0
]
pooled_prompt_embeds
=
prompt_embeds
[
0
]
prompt_embeds
=
prompt_embeds
.
hidden_states
[
-
2
]
prompt_embeds
=
prompt_embeds
[
-
1
]
[
-
2
]
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
bs_embed
,
seq_len
,
_
=
prompt_embeds
.
shape
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds
=
prompt_embeds
.
view
(
bs_embed
,
seq_len
,
-
1
)
prompt_embeds_list
.
append
(
prompt_embeds
)
prompt_embeds_list
.
append
(
prompt_embeds
)
...
@@ -955,6 +952,12 @@ def main(args):
...
@@ -955,6 +952,12 @@ def main(args):
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
accelerator
.
init_trackers
(
"text2image-fine-tune-sdxl"
,
config
=
vars
(
args
))
accelerator
.
init_trackers
(
"text2image-fine-tune-sdxl"
,
config
=
vars
(
args
))
# Helper for unwrapping a model: strips the accelerate wrapper and, if the model was compiled with torch.compile(), the compile wrapper as well.
def unwrap_model(model):
    """Return the underlying module for ``model``.

    First removes whatever wrapper ``accelerator.unwrap_model`` strips, then,
    if ``is_compiled_module`` reports the result as a compiled module, returns
    its ``_orig_mod`` attribute instead of the compiled wrapper.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # Compiled modules keep the original (uncompiled) module on `_orig_mod`.
        return unwrapped._orig_mod
    return unwrapped
# Train!
# Train!
total_batch_size
=
args
.
train_batch_size
*
accelerator
.
num_processes
*
args
.
gradient_accumulation_steps
total_batch_size
=
args
.
train_batch_size
*
accelerator
.
num_processes
*
args
.
gradient_accumulation_steps
...
@@ -1054,8 +1057,12 @@ def main(args):
...
@@ -1054,8 +1057,12 @@ def main(args):
pooled_prompt_embeds
=
batch
[
"pooled_prompt_embeds"
].
to
(
accelerator
.
device
)
pooled_prompt_embeds
=
batch
[
"pooled_prompt_embeds"
].
to
(
accelerator
.
device
)
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
unet_added_conditions
.
update
({
"text_embeds"
:
pooled_prompt_embeds
})
model_pred
=
unet
(
model_pred
=
unet
(
noisy_model_input
,
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
noisy_model_input
,
).
sample
timesteps
,
prompt_embeds
,
added_cond_kwargs
=
unet_added_conditions
,
return_dict
=
False
,
)[
0
]
# Get the target for loss depending on the prediction type
# Get the target for loss depending on the prediction type
if
args
.
prediction_type
is
not
None
:
if
args
.
prediction_type
is
not
None
:
...
@@ -1206,7 +1213,7 @@ def main(args):
...
@@ -1206,7 +1213,7 @@ def main(args):
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
accelerator
.
unwrap_model
(
unet
)
unet
=
unwrap_model
(
unet
)
if
args
.
use_ema
:
if
args
.
use_ema
:
ema_unet
.
copy_to
(
unet
.
parameters
())
ema_unet
.
copy_to
(
unet
.
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment