renzhc / diffusers_dcu · Commits

Unverified commit 7b93c2a8, authored Aug 16, 2023 by Suraj Patil, committed by GitHub on Aug 16, 2023
[research_projects] SDXL controlnet script (#4633)

add controlnet script
parent a7de9650

Showing 1 changed file with 1459 additions and 0 deletions.
examples/research_projects/controlnet/train_controlnet_webdataset.py · new file (mode 100644)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
import gc
import itertools
import json
import logging
import math
import os
import random
import shutil
from pathlib import Path
from typing import List, Optional, Union

import accelerate
import cv2
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
import webdataset as wds
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torch.utils.data import default_collate
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


MAX_SEQ_LENGTH = 77

if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")

logger = get_logger(__name__)
def filter_keys(key_set):
    def _f(dictionary):
        return {k: v for k, v in dictionary.items() if k in key_set}

    return _f
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample
def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples
def control_transform(image):
    image = np.array(image)

    low_threshold = 100
    high_threshold = 200

    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    control_image = Image.fromarray(image)
    return control_image
def canny_image_transform(example, resolution=1024):
    image = example["image"]
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    # get crop coordinates
    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
    image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
    control_image = control_transform(image)

    image = transforms.ToTensor()(image)
    image = transforms.Normalize([0.5], [0.5])(image)
    control_image = transforms.ToTensor()(control_image)

    example["image"] = image
    example["control_image"] = control_image
    example["crop_coords"] = (c_top, c_left)

    return example
def depth_image_transform(example, feature_extractor, resolution=1024):
    image = example["image"]
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    # get crop coordinates
    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
    image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)

    control_image = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)

    image = transforms.ToTensor()(image)
    image = transforms.Normalize([0.5], [0.5])(image)

    example["image"] = image
    example["control_image"] = control_image
    example["crop_coords"] = (c_top, c_left)

    return example
class WebdatasetFilter:
    def __init__(self, min_size=1024, max_pwatermark=0.5):
        self.min_size = min_size
        self.max_pwatermark = max_pwatermark

    def __call__(self, x):
        try:
            if "json" in x:
                x_json = json.loads(x["json"])
                filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
                    "original_height", 0
                ) >= self.min_size
                filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
                return filter_size and filter_watermark
            else:
                return False
        except Exception:
            return False
class Text2ImageDataset:
    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        eval_shards_path_or_url: Union[str, List[str]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        resolution: int = 256,
        center_crop: bool = True,
        random_flip: bool = False,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        control_type: str = "canny",
        feature_extractor: Optional[DPTFeatureExtractor] = None,
    ):
        if not isinstance(train_shards_path_or_url, str):
            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
            # flatten list using itertools
            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

        if not isinstance(eval_shards_path_or_url, str):
            eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
            # flatten list using itertools
            eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))

        def get_orig_size(json):
            return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))

        if control_type == "canny":
            image_transform = functools.partial(canny_image_transform, resolution=resolution)
        elif control_type == "depth":
            image_transform = functools.partial(
                depth_image_transform, feature_extractor=feature_extractor, resolution=resolution
            )

        processing_pipeline = [
            wds.decode("pil", handler=wds.ignore_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                control_image="jpg;png;jpeg;webp",
                text="text;txt;caption",
                orig_size="json",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys({"image", "control_image", "text", "orig_size"})),
            wds.map_dict(orig_size=get_orig_size),
            wds.map(image_transform),
            wds.to_tuple("image", "control_image", "text", "orig_size", "crop_coords"),
        ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            wds.select(WebdatasetFilter(min_size=512)),
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size

        # each worker is iterating over this
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add meta-data to dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader
        pipeline = [
            wds.SimpleShardList(eval_shards_path_or_url),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        return self._eval_dataloader
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation... ")

    controlnet = accelerator.unwrap_model(controlnet)

    pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    # pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    image_logs = []

    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
        validation_image = Image.open(validation_image).convert("RGB")
        validation_image = validation_image.resize((args.resolution, args.resolution))

        images = []

        for _ in range(args.num_validation_images):
            with torch.autocast("cuda"):
                image = pipeline(
                    validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                ).images[0]
            images.append(image)

        image_logs.append(
            {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
        )

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]

                formatted_images = []

                formatted_images.append(np.asarray(validation_image))

                for image in images:
                    formatted_images.append(np.asarray(image))

                formatted_images = np.stack(formatted_images)

                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            formatted_images = []

            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]

                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))

                for image in images:
                    image = wandb.Image(image, caption=validation_prompt)
                    formatted_images.append(image)

            tracker.log({"validation": formatted_images})
        else:
            logger.warn(f"image logging not implemented for {tracker.name}")

    del pipeline
    gc.collect()
    torch.cuda.empty_cache()

    return image_logs
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n"
        for i, log in enumerate(image_logs):
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
            validation_image.save(os.path.join(repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            # embed the saved grid in the model card (image link restored from the rendered page)
            img_str += f"![images_{i}](./images_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- controlnet
inference: true
---
    """
    model_card = f"""
# controlnet-{repo_id}

These are controlnet weights trained on {base_model} with new type of conditioning.
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
        " If not specified controlnet weights are initialized from unet.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="controlnet-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--crops_coords_top_left_h",
        type=int,
        default=0,
        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
    )
    parser.add_argument(
        "--crops_coords_top_left_w",
        type=int,
        default=0,
        help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=3,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help=("Number of subprocesses to use for data loading."),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by setting grads to None instead of zero. Be aware, that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--train_shards_path_or_url",
        type=str,
        default=None,
        help=(
            "The path or URL (brace-expandable) of the webdataset tar shards containing the training data. It can"
            " point to remote shards (e.g. via `pipe:`) or to a local copy of the shards in your filesystem."
        ),
    )
    parser.add_argument(
        "--eval_shards_path_or_url",
        type=str,
        default=None,
        help="The path or URL (brace-expandable) of the webdataset tar shards containing the evaluation data.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the controlnet conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running each `args.validation_prompt`"
            " `args.num_validation_images` times and logging the resulting images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="sd_xl_train_controlnet",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers; for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--control_type",
        type=str,
        default="canny",
        help=(
            "The type of controlnet conditioning image to use. One of `canny`, `depth`."
            " Defaults to `canny`."
        ),
    )
    parser.add_argument(
        "--transformer_layers_per_block",
        type=str,
        default=None,
        help=(
            "The number of layers per block in the transformer. If None, defaults to"
            " `args.transformer_layers`."
        ),
    )
    parser.add_argument(
        "--old_style_controlnet",
        action="store_true",
        default=False,
        help=(
            "Use the old style controlnet, which is a single transformer layer with"
            " a single head. Defaults to False."
        ),
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
        )

    return args
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
                token=args.hub_token,
                private=True,
            ).repo_id
    # Load the tokenizers
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )
    # Load scheduler and models
    # noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
    )
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
    )
    if args.controlnet_model_name_or_path:
        logger.info("Loading existing controlnet weights")
        pre_controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
    else:
        logger.info("Initializing controlnet weights from unet")
        pre_controlnet = ControlNetModel.from_unet(unet)

    if args.transformer_layers_per_block is not None:
        transformer_layers_per_block = [int(x) for x in args.transformer_layers_per_block.split(",")]
        down_block_types = ["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in transformer_layers_per_block]
        controlnet = ControlNetModel.from_config(
            pre_controlnet.config,
            down_block_types=down_block_types,
            transformer_layers_per_block=transformer_layers_per_block,
        )
        controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
        del pre_controlnet
    else:
        controlnet = pre_controlnet

    if args.control_type == "depth":
        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
        depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
        depth_model.requires_grad_(False)
    else:
        feature_extractor = None
        depth_model = None
    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            i = len(weights) - 1

            while len(weights) > 0:
                weights.pop()
                model = models[i]

                sub_dir = "controlnet"
                model.save_pretrained(os.path.join(output_dir, sub_dir))

                i -= 1

        def load_model_hook(models, input_dir):
            while len(models) > 0:
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    controlnet.train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing()

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
        raise ValueError(
            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
        )
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = controlnet.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    # The VAE is in float32 to avoid NaN losses.
    if args.pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
    if args.control_type == "depth":
        depth_model.to(accelerator.device, dtype=weight_dtype)
    # Here, we compute not just the text embeddings but also the additional embeddings
    # needed for the SD XL UNet to operate.
    def compute_embeddings(
        prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
    ):
        target_size = (args.resolution, args.resolution)
        original_sizes = list(map(list, zip(*original_sizes)))
        crops_coords_top_left = list(map(list, zip(*crop_coords)))

        original_sizes = torch.tensor(original_sizes, dtype=torch.long)
        crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)

        # crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
        )
        add_text_embeds = pooled_prompt_embeds

        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        # add_time_ids = list(crops_coords_top_left + target_size)
        add_time_ids = list(target_size)
        add_time_ids = torch.tensor([add_time_ids])
        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
        # add_time_ids = torch.cat([torch.tensor(original_sizes, dtype=torch.long), add_time_ids], dim=-1)
        add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)

        prompt_embeds = prompt_embeds.to(accelerator.device)
        add_text_embeds = add_text_embeds.to(accelerator.device)
        unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        # Look up the sigma for each sampled timestep and broadcast it to the latent shape.
        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
    dataset = Text2ImageDataset(
        train_shards_path_or_url=args.train_shards_path_or_url,
        eval_shards_path_or_url=args.eval_shards_path_or_url,
        num_train_examples=args.max_train_samples,
        per_gpu_batch_size=args.train_batch_size,
        global_batch_size=args.train_batch_size * accelerator.num_processes,
        num_workers=args.dataloader_num_workers,
        resolution=args.resolution,
        center_crop=False,
        random_flip=False,
        shuffle_buffer_size=1000,
        pin_memory=True,
        persistent_workers=True,
        control_type=args.control_type,
        feature_extractor=feature_extractor,
    )
    train_dataloader = dataset.train_dataloader

    # Let's first compute all the embeddings so that we can free up the text encoders
    # from memory.
    text_encoders = [text_encoder_one, text_encoder_two]
    tokenizers = [tokenizer_one, tokenizer_two]

    compute_embeddings_fn = functools.partial(
        compute_embeddings,
        proportion_empty_prompts=args.proportion_empty_prompts,
        text_encoders=text_encoders,
        tokenizers=tokenizers,
    )
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )
    # Prepare everything with our `accelerator`.
    controlnet, optimizer, lr_scheduler = accelerator.prepare(controlnet, optimizer, lr_scheduler)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))

        # tensorboard cannot handle list types for config
        tracker_config.pop("validation_prompt")
        tracker_config.pop("validation_image")

        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num batches each epoch = {train_dataloader.num_batches}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    image_logs = None
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                image, control_image, text, orig_size, crop_coords = batch

                encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
                image = image.to(accelerator.device, non_blocking=True)
                control_image = control_image.to(accelerator.device, non_blocking=True)

                if args.pretrained_vae_model_name_or_path is not None:
                    pixel_values = image.to(dtype=weight_dtype)
                    if vae.dtype != weight_dtype:
                        vae.to(dtype=weight_dtype)
                else:
                    pixel_values = image

                # latents = vae.encode(pixel_values).latent_dist.sample()
                # encode pixel values with batch size of at most 8
                latents = []
                for i in range(0, pixel_values.shape[0], 8):
                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
                latents = torch.cat(latents, dim=0)

                latents = latents * vae.config.scaling_factor
                if args.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

                if args.control_type == "depth":
                    control_image = control_image.to(weight_dtype)
                    with torch.autocast("cuda"):
                        depth_map = depth_model(control_image).predicted_depth

                    depth_map = torch.nn.functional.interpolate(
                        depth_map.unsqueeze(1),
                        size=image.shape[2:],
                        mode="bicubic",
                        align_corners=False,
                    )
                    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
                    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
                    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
                    control_image = (depth_map * 255.0).to(torch.uint8).float() / 255.0  # hack to match inference
                    control_image = torch.cat([control_image] * 3, dim=1)
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Scale the model input the same way EulerDiscreteScheduler does at inference:
                # divide by sqrt(sigma^2 + 1).
                sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
                inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)

                # ControlNet conditioning.
                controlnet_image = control_image.to(dtype=weight_dtype)
                prompt_embeds = encoded_text.pop("prompt_embeds")
                down_block_res_samples, mid_block_res_sample = controlnet(
                    inp_noisy_latents,
                    timesteps,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=encoded_text,
                    controlnet_cond=controlnet_image,
                    return_dict=False,
                )

                # Predict the noise residual
                model_pred = unet(
                    inp_noisy_latents,
                    timesteps,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=encoded_text,
                    down_block_additional_residuals=[
                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                    ],
                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
                ).sample

                # Turn the noise prediction into a prediction of the denoised latents
                # (x0 = x_t - sigma * eps in the Euler parameterization).
                model_pred = model_pred * (-sigmas) + noisy_latents
                weighing = sigmas**-2.0

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = latents  # compute loss against the denoised latents
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                loss = torch.mean(
                    (weighing.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = controlnet.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        image_logs = log_validation(
                            vae, unet, controlnet, args, accelerator, weight_dtype, global_step
                        )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = accelerator.unwrap_model(controlnet)
        controlnet.save_pretrained(args.output_dir)

        if args.push_to_hub:
            save_model_card(
                repo_id,
                image_logs=image_logs,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
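
A quick way to sanity-check the webdataset pipeline before launching a full training run is to instantiate `Text2ImageDataset` directly and pull a single batch. The snippet below is a minimal sketch, not part of the commit: it assumes the file above is saved locally as `train_controlnet_webdataset.py`, and the shard URLs are hypothetical placeholders to be replaced with real webdataset tars.

# Minimal sketch: fetch one canny-conditioned batch from the pipeline above.
from train_controlnet_webdataset import Text2ImageDataset  # assumed local module name

dataset = Text2ImageDataset(
    train_shards_path_or_url="pipe:curl -s -L https://example.org/data-{000000..000009}.tar",  # hypothetical shards
    eval_shards_path_or_url="pipe:curl -s -L https://example.org/data-000010.tar",  # hypothetical shard
    num_train_examples=1000,
    per_gpu_batch_size=2,
    global_batch_size=2,
    num_workers=1,
    resolution=1024,
    control_type="canny",
)

# Each batch is the tuple produced by wds.to_tuple(...) in the processing pipeline.
image, control_image, text, orig_size, crop_coords = next(iter(dataset.train_dataloader))
print(image.shape, control_image.shape)  # expected: torch.Size([2, 3, 1024, 1024]) for both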