Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a9b27b92
Unverified
Commit
a9b27b92
authored
Jan 04, 2023
by
Fazzie-Maqianli
Committed by
GitHub
Jan 04, 2023
Browse files
[example] fix dreambooth format (#2315)
parent
da1c47f0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
90 additions
and
101 deletions
+90
-101
examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
...ges/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
+1
-1
examples/images/diffusion/configs/train_colossalai.yaml
examples/images/diffusion/configs/train_colossalai.yaml
+1
-1
examples/images/diffusion/configs/train_colossalai_cifar10.yaml
...es/images/diffusion/configs/train_colossalai_cifar10.yaml
+1
-1
examples/images/diffusion/configs/train_pokemon.yaml
examples/images/diffusion/configs/train_pokemon.yaml
+1
-1
examples/images/dreambooth/train_dreambooth_colossalai.py
examples/images/dreambooth/train_dreambooth_colossalai.py
+86
-97
No files found.
examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
View file @
a9b27b92
...
...
@@ -108,7 +108,7 @@ lightning:
params
:
use_chunk
:
True
enable_distributed_storage
:
True
placement_policy
:
auto
placement_policy
:
cuda
force_outputs_fp32
:
true
log_every_n_steps
:
2
...
...
examples/images/diffusion/configs/train_colossalai.yaml
View file @
a9b27b92
...
...
@@ -105,7 +105,7 @@ lightning:
params
:
use_chunk
:
True
enable_distributed_storage
:
True
placement_policy
:
auto
placement_policy
:
cuda
force_outputs_fp32
:
true
log_every_n_steps
:
2
...
...
examples/images/diffusion/configs/train_colossalai_cifar10.yaml
View file @
a9b27b92
...
...
@@ -109,7 +109,7 @@ lightning:
params
:
use_chunk
:
True
enable_distributed_storage
:
True
placement_policy
:
auto
placement_policy
:
cuda
force_outputs_fp32
:
true
log_every_n_steps
:
2
...
...
examples/images/diffusion/configs/train_pokemon.yaml
View file @
a9b27b92
...
...
@@ -102,7 +102,7 @@ lightning:
params
:
use_chunk
:
True
enable_distributed_storage
:
True
placement_policy
:
auto
placement_policy
:
cuda
force_outputs_fp32
:
true
log_every_n_steps
:
2
...
...
examples/images/dreambooth/train_dreambooth_colossalai.py
View file @
a9b27b92
import
argparse
import
hashlib
import
itertools
import
math
import
os
from
pathlib
import
Path
from
typing
import
Optional
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
from
copy
import
deepcopy
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers.optimization
import
get_scheduler
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
packaging
import
version
from
PIL
import
Image
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
transformers
import
AutoTokenizer
,
PretrainedConfig
import
colossalai
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
from
colossalai.nn.optimizer.gemini_optimizer
import
GeminiAdamOptimizer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.parallel.utils
import
convert_to_torch_module
from
colossalai.tensor
import
ColoTensor
,
ProcessGroup
from
colossalai.tensor
import
ProcessGroup
from
colossalai.utils
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers.optimization
import
get_scheduler
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
PIL
import
Image
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
transformers
import
AutoTokenizer
,
PretrainedConfig
disable_existing_loggers
()
logger
=
get_dist_logger
()
...
...
@@ -118,8 +112,10 @@ def parse_args(input_args=None):
"--num_class_images"
,
type
=
int
,
default
=
100
,
help
=
(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
help
=
(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser
.
add_argument
(
"--output_dir"
,
...
...
@@ -132,23 +128,26 @@ def parse_args(input_args=None):
"--resolution"
,
type
=
int
,
default
=
512
,
help
=
(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
help
=
(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser
.
add_argument
(
"--placement"
,
type
=
str
,
default
=
'
cpu
'
,
default
=
"
cpu
"
,
help
=
"Placement Policy for Gemini. Valid when using colossalai as dist plan."
,
)
parser
.
add_argument
(
"--center_crop"
,
action
=
"store_true"
,
help
=
"Whether to center crop images before resizing to resolution"
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
4
,
help
=
"Batch size (per device) for the training dataloader."
)
parser
.
add_argument
(
"--sample_batch_size"
,
type
=
int
,
default
=
4
,
help
=
"Batch size (per device) for sampling images."
)
parser
.
add_argument
(
"--center_crop"
,
action
=
"store_true"
,
help
=
"Whether to center crop images before resizing to resolution"
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
4
,
help
=
"Batch size (per device) for the training dataloader."
)
parser
.
add_argument
(
"--sample_batch_size"
,
type
=
int
,
default
=
4
,
help
=
"Batch size (per device) for sampling images."
)
parser
.
add_argument
(
"--num_train_epochs"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--max_train_steps"
,
...
...
@@ -184,16 +183,17 @@ def parse_args(input_args=None):
"--lr_scheduler"
,
type
=
str
,
default
=
"constant"
,
help
=
(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
help
=
(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser
.
add_argument
(
"--lr_warmup_steps"
,
type
=
int
,
default
=
500
,
help
=
"Number of steps for the warmup in the lr scheduler."
)
parser
.
add_argument
(
"--use_8bit_adam"
,
action
=
"store_true"
,
help
=
"Whether or not to use 8-bit Adam from bitsandbytes."
)
parser
.
add_argument
(
"--lr_warmup_steps"
,
type
=
int
,
default
=
500
,
help
=
"Number of steps for the warmup in the lr scheduler."
)
parser
.
add_argument
(
"--use_8bit_adam"
,
action
=
"store_true"
,
help
=
"Whether or not to use 8-bit Adam from bitsandbytes."
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
,
help
=
"Whether or not to push the model to the Hub."
)
...
...
@@ -208,8 +208,10 @@ def parse_args(input_args=None):
"--logging_dir"
,
type
=
str
,
default
=
"logs"
,
help
=
(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
help
=
(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser
.
add_argument
(
"--mixed_precision"
,
...
...
@@ -219,7 +221,8 @@ def parse_args(input_args=None):
help
=
(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
default
=-
1
,
help
=
"For distributed training: local_rank"
)
...
...
@@ -285,12 +288,14 @@ class DreamBoothDataset(Dataset):
else
:
self
.
class_data_root
=
None
self
.
image_transforms
=
transforms
.
Compose
([
transforms
.
Resize
(
size
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
size
)
if
center_crop
else
transforms
.
RandomCrop
(
size
),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
],
[
0.5
]),
])
self
.
image_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
size
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
),
transforms
.
CenterCrop
(
size
)
if
center_crop
else
transforms
.
RandomCrop
(
size
),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
],
[
0.5
]),
]
)
def
__len__
(
self
):
return
self
.
_length
...
...
@@ -352,26 +357,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
# Gemini + ZeRO DDP
def
gemini_zero_dpp
(
model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
,
placememt_policy
:
str
=
"auto"
):
cai_version
=
colossalai
.
__version__
if
version
.
parse
(
cai_version
)
>
version
.
parse
(
"0.1.10"
):
from
colossalai.nn.parallel
import
GeminiDDP
model
=
GeminiDDP
(
model
,
device
=
get_current_device
(),
placement_policy
=
placememt_policy
,
pin_memory
=
True
,
search_range_mb
=
32
)
elif
version
.
parse
(
cai_version
)
<=
version
.
parse
(
"0.1.10"
)
and
version
.
parse
(
cai_version
)
>=
version
.
parse
(
"0.1.9"
):
from
colossalai.gemini
import
ChunkManager
,
GeminiManager
chunk_size
=
ChunkManager
.
search_chunk_size
(
model
,
64
*
1024
**
2
,
32
)
gemini_manager
=
GeminiManager
(
placememt_policy
,
chunk_manager
)
chunk_manager
=
ChunkManager
(
chunk_size
,
pg
,
enable_distributed_storage
=
True
,
init_device
=
GeminiManager
.
get_default_device
(
placememt_policy
))
model
=
ZeroDDP
(
model
,
gemini_manager
)
else
:
raise
NotImplemented
(
f
"CAI version
{
cai_version
}
is not supported"
)
from
colossalai.nn.parallel
import
GeminiDDP
model
=
GeminiDDP
(
model
,
device
=
get_current_device
(),
placement_policy
=
placememt_policy
,
pin_memory
=
True
,
search_range_mb
=
32
)
return
model
...
...
@@ -383,7 +373,7 @@ def main(args):
"gradient_accumulation_steps"
:
args
.
gradient_accumulation_steps
,
"clip_grad_norm"
:
args
.
max_grad_norm
,
}
colossalai
.
launch_from_torch
(
config
=
config
)
pg
=
ProcessGroup
()
...
...
@@ -414,9 +404,11 @@ def main(args):
pipeline
.
to
(
get_current_device
())
for
example
in
tqdm
(
sample_dataloader
,
desc
=
"Generating class images"
,
disable
=
not
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
0
):
for
example
in
tqdm
(
sample_dataloader
,
desc
=
"Generating class images"
,
disable
=
not
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
0
,
):
images
=
pipeline
(
example
[
"prompt"
]).
images
for
i
,
image
in
enumerate
(
images
):
...
...
@@ -466,23 +458,24 @@ def main(args):
logger
.
info
(
f
"Loading text_encoder from
{
args
.
pretrained_model_name_or_path
}
"
,
ranks
=
[
0
])
text_encoder
=
text_encoder_cls
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,)
text_encoder
=
text_encoder_cls
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"text_encoder"
,
revision
=
args
.
revision
,
)
logger
.
info
(
f
"Loading AutoencoderKL from
{
args
.
pretrained_model_name_or_path
}
"
,
ranks
=
[
0
])
vae
=
AutoencoderKL
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"vae"
,
revision
=
args
.
revision
,)
vae
=
AutoencoderKL
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"vae"
,
revision
=
args
.
revision
,
)
logger
.
info
(
f
"Loading UNet2DConditionModel from
{
args
.
pretrained_model_name_or_path
}
"
,
ranks
=
[
0
])
with
ColoInitContext
():
unet
=
UNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
,
low_cpu_mem_usage
=
False
)
unet
=
UNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
,
low_cpu_mem_usage
=
False
)
vae
.
requires_grad_
(
False
)
text_encoder
.
requires_grad_
(
False
)
...
...
@@ -491,7 +484,7 @@ def main(args):
unet
.
enable_gradient_checkpointing
()
if
args
.
scale_lr
:
args
.
learning_rate
=
(
args
.
learning_rate
*
args
.
gradient_accumulation_steps
*
args
.
train_batch_size
*
2
)
args
.
learning_rate
=
args
.
learning_rate
*
args
.
gradient_accumulation_steps
*
args
.
train_batch_size
*
2
unet
=
gemini_zero_dpp
(
unet
,
pg
,
args
.
placement
)
...
...
@@ -502,7 +495,7 @@ def main(args):
noise_scheduler
=
DDPMScheduler
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"scheduler"
)
# prepare dataset
logger
.
info
(
f
"Prepare dataset"
,
ranks
=
[
0
])
logger
.
info
(
f
"Prepare dataset
from
{
args
.
instance_data_dir
}
"
,
ranks
=
[
0
])
train_dataset
=
DreamBoothDataset
(
instance_data_root
=
args
.
instance_data_dir
,
instance_prompt
=
args
.
instance_prompt
,
...
...
@@ -527,9 +520,7 @@ def main(args):
pixel_values
=
pixel_values
.
to
(
memory_format
=
torch
.
contiguous_format
).
float
()
input_ids
=
tokenizer
.
pad
(
{
"input_ids"
:
input_ids
},
{
"input_ids"
:
input_ids
},
padding
=
"max_length"
,
max_length
=
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
...
...
@@ -541,11 +532,9 @@ def main(args):
}
return
batch
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
collate_fn
,
num_workers
=
1
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
collate_fn
,
num_workers
=
1
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps
=
False
...
...
@@ -662,8 +651,8 @@ def main(args):
global_step
+=
1
logs
=
{
"loss"
:
loss
.
detach
().
item
(),
"lr"
:
optimizer
.
param_groups
[
0
][
'
lr
'
]
}
#lr_scheduler.get_last_lr()[0]}
"lr"
:
optimizer
.
param_groups
[
0
][
"
lr
"
],
}
#
lr_scheduler.get_last_lr()[0]}
progress_bar
.
set_postfix
(
**
logs
)
if
global_step
%
args
.
save_steps
==
0
:
...
...
@@ -681,15 +670,15 @@ def main(args):
break
torch
.
cuda
.
synchronize
()
unet
=
convert_to_torch_module
(
unet
)
unet
=
convert_to_torch_module
(
unet
)
if
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
0
:
pipeline
=
DiffusionPipeline
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
unet
=
unet
,
revision
=
args
.
revision
,
)
pipeline
.
save_pretrained
(
args
.
output_dir
)
logger
.
info
(
f
"Saving model checkpoint to
{
args
.
output_dir
}
"
,
ranks
=
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment