renzhc / diffusers_dcu

Unverified commit 089f0f4c, authored Jan 10, 2023 by Fazzie-Maqianli, committed by GitHub Jan 09, 2023

update to latest colossalai (#1951)
parent aba2a65d

Showing 1 changed file with 18 additions and 20 deletions (+18 -20).

examples/research_projects/colossalai/train_dreambooth_colossalai.py @ 089f0f4c
@@ -15,8 +15,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
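The import change above is the core of the migration: the helper that materializes a plain torch module moves from convert_to_torch_module to get_static_torch_model, and ProcessGroup is no longer imported at all. A minimal sketch, with names taken directly from the hunk:

    # New-style imports per this commit; the commented lines are the ones being removed.
    from colossalai.nn.parallel.utils import get_static_torch_model   # was: convert_to_torch_module
    # from colossalai.nn.parallel.utils import convert_to_torch_module  # removed
    # from colossalai.tensor import ProcessGroup                        # removed: GeminiDDP no longer takes one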
@@ -356,26 +355,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP

     model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
     )
     return model


 def main(args):
     # config for colossalai
-
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})

     if args.seed is not None:
         gpc.set_seed(args.seed)
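For context, a self-contained sketch of the new wrapping flow, assuming the colossalai version this commit targets; wrap_with_gemini is an illustrative name, not the script's (the script keeps calling it gemini_zero_dpp):

    import torch
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils import get_current_device

    def wrap_with_gemini(model: torch.nn.Module, placement_policy: str = "auto") -> torch.nn.Module:
        # No ProcessGroup argument any more: GeminiDDP only needs the model,
        # a device, a placement policy, and chunk-search settings.
        # This commit also bumps search_range_mb from 32 to 64.
        return GeminiDDP(
            model,
            device=get_current_device(),
            placement_policy=placement_policy,
            pin_memory=True,
            search_range_mb=64,
        )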
@@ -472,7 +462,7 @@ def main(args):
     )

     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
         )
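A minimal sketch of the updated initialization pattern, under the same colossalai-version assumption; load_unet and its arguments are illustrative, not part of the script:

    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from diffusers import UNet2DConditionModel

    def load_unet(pretrained_model_name_or_path: str, revision=None):
        # ColoInitContext now takes the target device explicitly, so parameters
        # are created there instead of on the default device.
        with ColoInitContext(device=get_current_device()):
            unet = UNet2DConditionModel.from_pretrained(
                pretrained_model_name_or_path, subfolder="unet", revision=revision, low_cpu_mem_usage=False
            )
        return unet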
@@ -484,12 +474,19 @@ def main(args):
         unet.enable_gradient_checkpointing()

     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
+        args.learning_rate = (
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * gpc.get_world_size(ParallelMode.DATA)
+        )

-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)

     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(
+        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+    )

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
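The hunk above makes two related changes: learning-rate scaling uses the actual data-parallel world size instead of a hard-coded 2, and gradient clipping moves from the launch config into the optimizer via clipping_norm. A sketch under those assumptions, with build_optimizer as an illustrative helper name:

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

    def build_optimizer(unet, base_lr, grad_accum_steps, batch_size, max_grad_norm, scale_lr=True):
        lr = base_lr
        if scale_lr:
            # scale by gradient accumulation, per-device batch size, and data-parallel world size
            lr = base_lr * grad_accum_steps * batch_size * gpc.get_world_size(ParallelMode.DATA)
        # clipping_norm replaces the old "clip_grad_norm" entry of the launch config
        return GeminiAdamOptimizer(unet, lr=lr, initial_scale=2**5, clipping_norm=max_grad_norm)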
@@ -657,10 +654,11 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                torch_unet = get_static_torch_model(unet)
                 if gpc.get_local_rank(ParallelMode.DATA) == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=convert_to_torch_module(unet),
+                        unet=torch_unet,
                         revision=args.revision,
                     )
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
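A sketch of the new checkpointing path, assuming get_static_torch_model gathers the Gemini-sharded weights into an ordinary torch module; save_checkpoint is an illustrative wrapper, and the pipeline.save_pretrained call at the end is implied by the surrounding script rather than shown in this hunk:

    import os
    import torch
    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.nn.parallel.utils import get_static_torch_model
    from diffusers import DiffusionPipeline

    def save_checkpoint(unet, args, global_step):
        torch.cuda.synchronize()
        # gather once, outside the rank-0 branch, so every rank participates
        torch_unet = get_static_torch_model(unet)
        if gpc.get_local_rank(ParallelMode.DATA) == 0:
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, unet=torch_unet, revision=args.revision
            )
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            pipeline.save_pretrained(save_path)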
@@ -670,7 +668,7 @@ def main(args):
                 break

     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)

     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(