OpenDAS / ColossalAI · Commit 47ecb223 (unverified)

Authored Feb 20, 2023 by Haofan Wang; committed by GitHub on Feb 20, 2023
Parent: b6a108cb

[example] add LoRA support (#2821)

* add lora
* format
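For context, a minimal sketch of driving the new script programmatically through its own parse_args(input_args=...) hook. The model ID, paths, and step count below are illustrative placeholders, not values mandated by the example, and the process still has to be started under a distributed launcher such as torchrun, since colossalai.launch_from_torch reads the launcher's environment variables.

# Illustrative only: run this under torchrun (e.g. `torchrun --nproc_per_node=1 launch.py`),
# because colossalai.launch_from_torch expects RANK/WORLD_SIZE/MASTER_ADDR in the env.
from train_dreambooth_colossalai_lora import parse_args, main

args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",  # placeholder model
    "--instance_data_dir", "./instance_images",                           # placeholder data dir
    "--instance_prompt", "a photo of sks dog",
    "--output_dir", "./lora_output",                                      # placeholder output dir
    "--max_train_steps", "400",
])
main(args)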
Changes: 1 changed file, with 691 additions and 0 deletions.

examples/images/dreambooth/train_dreambooth_colossalai_lora.py (new file, mode 100644, +691 -0)
import argparse
import hashlib
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

disable_existing_loggers()
logger = get_dist_logger()
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    # `revision` is passed in explicitly rather than read from the global `args`,
    # so the helper also works when this module is imported.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default="a photo of sks dog",
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument("--save_steps", type=int, default=500, help="Save a checkpoint every X update steps.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            logger.warning("You need not use --class_prompt without --with_prior_preservation.")

    return args
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
    from colossalai.nn.parallel import GeminiDDP

    model = GeminiDDP(
        model, device=get_current_device(), placement_policy=placement_policy, pin_memory=True, search_range_mb=64
    )
    return model
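# GeminiDDP shards parameters and gradients ZeRO-style; placement_policy controls
# where parameter chunks rest ("cpu" offloads them to host memory, "cuda" keeps them
# on the GPU, "auto" migrates chunks between host and device based on memory usage),
# which is what keeps the UNet fine-tuning within a single GPU's memory budget.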
def main(args):
    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            # get_current_device() returns a torch.device, so compare its .type.
            torch_dtype = torch.float16 if get_current_device().type == "cuda" else torch.float32
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            pipeline.to(get_current_device())

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not local_rank == 0,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline

    # Handle the repository creation
    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    # Load the tokenizer
    if args.tokenizer_name:
        logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
    elif args.pretrained_model_name_or_path:
        logger.info("Loading tokenizer from pretrained model", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
    # Load models and create wrapper for stable diffusion
    logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):
        unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
        )
        unet.requires_grad_(False)
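    # The loop below attaches a LoRACrossAttnProcessor to every attention module of the
    # frozen UNet: cross_attention_dim is None for self-attention ("attn1") processors and
    # equal to the text-embedding width for cross-attention, while hidden_size must match
    # the channel width of the enclosing down/mid/up block, hence the block_out_channels
    # lookups. Only these processor weights end up trainable.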
    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
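    # Note: the optimizer wraps the whole Gemini-sharded UNet rather than just
    # lora_layers; since everything except the LoRA attention processors was frozen
    # with requires_grad_(False), only the LoRA weights actually receive updates.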
    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    # prepare dataset
    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
    )
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # Train!
    total_batch_size = args.train_batch_size * world_size

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
    progress_bar.set_description("Steps")
    global_step = 0

    torch.cuda.synchronize()
    for epoch in range(args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            torch.cuda.reset_peak_memory_stats()
            # Move batch to gpu
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            # Convert images to latent space
            optimizer.zero_grad()

            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
            # Checks if the accelerator has performed an optimization step behind the scenes
            progress_bar.update(1)
            global_step += 1
            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }    # lr_scheduler.get_last_lr()[0]
            progress_bar.set_postfix(**logs)

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                torch_unet = get_static_torch_model(unet)
                if local_rank == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    torch_unet = torch_unet.to(torch.float32)
                    torch_unet.save_attn_procs(save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])

            if global_step >= args.max_train_steps:
                break
    torch.cuda.synchronize()
    torch_unet = get_static_torch_model(unet)
    if local_rank == 0:
        # Save the final weights to output_dir (save_path is only defined when an
        # intermediate checkpoint was written, so it cannot be used here).
        torch_unet = torch_unet.to(torch.float32)
        torch_unet.save_attn_procs(args.output_dir)
        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

        # The repo only exists on rank 0, so push from rank 0 only.
        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__":
    args = parse_args()
    main(args)
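The checkpoints written above via save_attn_procs contain only the LoRA attention weights, not a full UNet. A minimal sketch of loading them back for inference, assuming a diffusers version from the same era as this commit (where the UNet exposes load_attn_procs); the model ID, prompt, and paths are placeholders matching the defaults used earlier.

import torch
from diffusers import DiffusionPipeline

# Placeholder model ID and output dir; use the values passed to the training run.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to("cuda")

# Counterpart of save_attn_procs in the training script: swaps the trained
# LoRA attention processors into the base UNet.
pipe.unet.load_attn_procs("./lora_output")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("sks-dog.png")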