OpenDAS / ColossalAI / Commits

Commit 304f1ba1 (unverified), authored Jan 19, 2023 by Fazzie-Maqianli, committed by GitHub on Jan 19, 2023

Merge pull request #2499 from feifeibear/dev0116_10

[example] check dreambooth example gradient accumulation must be 1
Parents: 5db3a5bf, 32390cbe

Showing 2 changed files with 13 additions and 8 deletions (+13 -8):

  examples/images/dreambooth/test_ci.sh                        +0  -0
  examples/images/dreambooth/train_dreambooth_colossalai.py    +13 -8
examples/images/dreambooth/test_ci.sh (new file, mode 0 → 100644)
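The new CI script's contents are not shown in this view (the summary reports +0 -0 for it). Presumably it smoke-tests the example; a hypothetical invocation consistent with the flags this commit touches would be along the lines of "torchrun --nproc_per_node 2 train_dreambooth_colossalai.py --gradient_accumulation_steps 1 ..." with the remaining arguments elided, since colossalai.launch_from_torch picks up the rank environment that torchrun sets.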
examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -153,7 +153,8 @@ def parse_args(input_args=None):
         "--gradient_accumulation_steps",
         type=int,
         default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
+        help="Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
     )
     parser.add_argument(
         "--gradient_checkpointing",
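For context, the touched option reads as follows after the change. This is a trimmed, self-contained sketch (the real parser defines many more flags):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--gradient_accumulation_steps",
    type=int,
    default=1,
    help="Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
)
args = parser.parse_args([])             # no CLI args given, so the default applies
print(args.gradient_accumulation_steps)  # 1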
@@ -361,6 +362,9 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)

+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
+
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
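These two additions look the rank and world size up once and reuse the result everywhere below, instead of repeating the gpc call at each use site. A minimal sketch of the same pattern; it substitutes plain torch.distributed for ColossalAI's gpc/ParallelMode purely to stay self-contained:

import torch.distributed as dist

def get_parallel_info():
    # Assumes an initialized process group (e.g. launched via torchrun).
    # In the script these values come from gpc.get_local_rank(ParallelMode.DATA)
    # and gpc.get_world_size(ParallelMode.DATA).
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size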
@@ -388,7 +392,7 @@ def main(args):
         for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
         ):
             images = pipeline(example["prompt"]).images
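The disable=not local_rank == 0 idiom keeps the progress bar on rank 0 and silences it everywhere else. A runnable single-process illustration with a pretend rank:

from tqdm import tqdm

local_rank = 0  # illustrative; in the script this comes from gpc
for _ in tqdm(range(10), desc="Generating class images", disable=not local_rank == 0):
    pass  # the bar renders because local_rank == 0; any other rank would suppress it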
@@ -400,7 +404,7 @@ def main(args):
         del pipeline

     # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         if args.push_to_hub:
             if args.hub_model_id is None:
                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -465,8 +469,9 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size

     unet = gemini_zero_dpp(unet, args.placement)
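With the new assert in place, the --scale_lr rule reduces to base learning rate times batch size times world size. A worked example with illustrative numbers (none of them come from the commit):

base_lr = 5e-6
gradient_accumulation_steps = 1  # enforced by the new assert under ColossalAI
train_batch_size = 2             # hypothetical per-GPU batch size
world_size = 4                   # hypothetical number of data-parallel ranks

scaled_lr = base_lr * gradient_accumulation_steps * train_batch_size * world_size
print(scaled_lr)  # 4e-05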
@@ -555,7 +560,7 @@ def main(args):
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps

     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
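The logged effective batch size follows the same formula; continuing with the illustrative numbers above:

train_batch_size = 2
world_size = 4
gradient_accumulation_steps = 1  # fixed to 1 under ColossalAI
total_batch_size = train_batch_size * world_size * gradient_accumulation_steps
print(total_batch_size)  # 8 samples consumed per optimizer step across all ranks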
@@ -567,7 +572,7 @@ def main(args):
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0
@@ -644,7 +649,7 @@ def main(args):
                 if global_step % args.save_steps == 0:
                     torch.cuda.synchronize()
                     torch_unet = get_static_torch_model(unet)
-                    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                    if local_rank == 0:
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
                             unet=torch_unet,
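Note the ordering here: get_static_torch_model(unet) runs on every rank before the rank check, which is consistent with it gathering the Gemini-sharded parameters collectively, while only rank 0 builds the pipeline. A hedged sketch of the pattern; the save_pretrained call is an assumption, and get_static_torch_model is taken to be imported as the script does (its import is not shown in this diff):

import torch
from diffusers import DiffusionPipeline

def maybe_save(unet, args, local_rank, global_step):
    if global_step % args.save_steps == 0:
        torch.cuda.synchronize()                   # let every rank finish the step
        torch_unet = get_static_torch_model(unet)  # collective: call on all ranks
        if local_rank == 0:                        # ...but only rank 0 writes
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=torch_unet,
            )
            pipeline.save_pretrained(args.output_dir)  # assumed save step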
@@ -659,7 +664,7 @@ def main(args):
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,