OpenDAS / ColossalAI · commit 48d33b1b (unverified)

[gemini] add get static torch model (#2356)

Authored Jan 06, 2023 by HELSON; committed by GitHub on Jan 06, 2023.
Parent: 7a332b17

Showing 4 changed files with 164 additions and 118 deletions (+164 −118):
    colossalai/nn/parallel/data_parallel.py                     +1   −13
    colossalai/nn/parallel/utils.py                             +78  −15
    examples/images/dreambooth/train_dreambooth_colossalai.py   +67  −81
    tests/test_gemini/update/test_get_torch_model.py            +18  −9

colossalai/nn/parallel/data_parallel.py (view file @ 48d33b1b)
@@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP):
         del temp_chunk
         return param_to_save_data

-    def torch_named_parameters(self):
-        """
-        get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
-        It works the same as torch.Module named_parameters
-        """
-        params_list = [p for p in self.parameters(recurse=True)]
-        param_to_save_data = self._get_param_to_save_data(params_list, False)
-        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
-            if p is not None:
-                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-                record_parameter = param_to_save_data[p]
-                yield name, record_parameter
-
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every

@@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # TODO: (HELSON) deal with ddp ignored parameters
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
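
Note that the removed torch_named_parameters() generator was the only public way to read the gathered param.data payloads off a ZeroDDP module; its role is taken over by get_static_torch_model() in colossalai/nn/parallel/utils.py (next file). A minimal migration sketch, assuming gemini_model is a GeminiDDP-wrapped module as in the test file at the bottom of this commit:

    from colossalai.nn.parallel.utils import get_static_torch_model

    def dump_param_shapes(gemini_model):
        # The static copy holds plain torch.nn.Parameter objects, so an ordinary
        # named_parameters() loop replaces the removed torch_named_parameters().
        static_model = get_static_torch_model(gemini_model, only_rank_0=False)
        for name, param in static_model.named_parameters():
            print(name, tuple(param.shape), param.dtype)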

colossalai/nn/parallel/utils.py (view file @ 48d33b1b)
+from collections import OrderedDict
+from copy import copy
+from typing import Optional, Set
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn

 from colossalai.gemini.chunk import Chunk
 from colossalai.utils import get_current_device

@@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp


-def _add_param(model, name, param):
-    name_list = name.split('.')
-    module = model._modules[name_list[0]]
-    for i in range(1, len(name_list) - 1):
-        module = module._modules[name_list[i]]
-    module._parameters[name_list[-1]] = param
+# TODO() not work for module where two params share the same tensor.
+def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
+    """Get a dfs module list of the given module. Its order is same as the order of creations of modules.
+    """
+    if memo is None:
+        memo = set()
+    if module not in memo:
+        for name, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
+                yield m
+
+        memo.add(module)
+        yield prefix, module


-def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
-    """convert_to_torch_module
+def _get_shallow_copy_model(model: nn.Module):
+    """Get a shallow copy of the given model. Each submodule is different from the original submodule.
+    But the new submodule and the old submodule share all attributes.
+    """
+    name_to_module = dict()
+    for name, module in _get_dfs_module_list(model):
+        new_module = copy(module)
+        new_module._modules = OrderedDict()
+        for subname, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            full_name = name + ('.' if name else '') + subname
+            setattr(new_module, subname, name_to_module[full_name])
+        name_to_module[name] = new_module
+    return name_to_module['']
+
+
+def get_static_torch_model(gemini_ddp_model,
+                           device=torch.device("cpu"),
+                           dtype=torch.float32,
+                           only_rank_0=True) -> torch.nn.Module:
+    """Get a static torch.nn.Module model from the given GeminiDDP module.
+    You should notice that the original GeminiDDP model is not modified.
+    Thus, you can use the original model in further training.
+    But you should not use the returned torch model to train, this can cause unexpected errors.

     Args:
         gemini_ddp_model (GeminiDDP): a gemini ddp model
+        device (torch.device): the device of the final torch model
+        dtype (torch.dtype): the dtype of the final torch model
+        only_rank_0 (bool): if True, only rank0 has the converted torch model

     Returns:
-        torch.nn.Module: a torch model contains the params of gemini_ddp_model
+        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
     from colossalai.nn.parallel import GeminiDDP
     assert isinstance(gemini_ddp_model, GeminiDDP)

-    module = gemini_ddp_model.module
-    # replace ColoTensor to torch.nn.Tensor in module
-    for n, p in gemini_ddp_model.torch_named_parameters():
-        _add_param(module, n, p)
-    return module
+    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
+    colo_model = gemini_ddp_model.module
+    torch_model = _get_shallow_copy_model(colo_model)
+
+    if not only_rank_0 or dist.get_rank() == 0:
+        # record the mapping relationship between colo parameters and torch parameters
+        colo_to_torch = dict()
+        for (name, colo_module), (_, torch_module) in \
+                zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
+            # clean the parameter list of the new torch module
+            torch_module._parameters = OrderedDict()
+            for sufix_param_name, param in colo_module.named_parameters(recurse=False):
+                # get the full name of the parameter
+                full_param_name = name + ('.' if name else '') + sufix_param_name
+                if full_param_name not in state_dict:
+                    # this means the parameter is shared by multiple modules
+                    # we should use colo_to_torch to get the torch parameter created before
+                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
+                    torch_param = colo_to_torch[param]
+                else:
+                    # we meet the parameter the first time, just use the state dict to get the data
+                    state_param = state_dict[full_param_name]
+                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
+                    colo_to_torch[param] = torch_param
+                setattr(torch_module, sufix_param_name, torch_param)
+    dist.barrier()
+    return torch_model
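
A hedged usage sketch of the new helper, built only from the signatures in this file and the GeminiDDP construction used in the test below; it assumes the process group has already been initialised (for example via colossalai.launch_from_torch, as in the DreamBooth example), and the small Sequential model stands in for any module created under ColoInitContext:

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from colossalai.nn.parallel import GeminiDDP
    from colossalai.nn.parallel.utils import get_static_torch_model
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    with ColoInitContext(device=torch.device("cpu")):
        model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

    # ... train with a Gemini optimizer ...

    # Gather a plain fp32 CPU copy; the GeminiDDP model itself is left untouched,
    # so training can continue with it afterwards.
    torch_model = get_static_torch_model(model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), "static_model.pt")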

examples/images/dreambooth/train_dreambooth_colossalai.py (view file @ 48d33b1b)
@@ -8,25 +8,23 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from PIL import Image
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig

 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig

 disable_existing_loggers()
 logger = get_dist_logger()

@@ -112,10 +110,8 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
         help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
               " class_data_dir, additional images will be sampled with class_prompt."),
     )
     parser.add_argument(
         "--output_dir",

@@ -128,10 +124,8 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
         help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
               " resolution"),
     )
     parser.add_argument(
         "--placement",

@@ -139,15 +133,14 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
     parser.add_argument("--center_crop",
                         action="store_true",
                         help="Whether to center crop images before resizing to resolution")
     parser.add_argument("--train_batch_size",
                         type=int,
                         default=4,
                         help="Batch size (per device) for the training dataloader.")
     parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",

@@ -183,17 +176,16 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
         help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
               ' "constant", "constant_with_warmup"]'),
     )
     parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
     parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")

@@ -208,10 +200,8 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
         help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
               " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
     )
     parser.add_argument(
         "--mixed_precision",

@@ -221,8 +211,7 @@ def parse_args(input_args=None):
         help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
               " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
               " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

@@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None

         self.image_transforms = transforms.Compose([
             transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
             transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ])

     def __len__(self):
         return self._length

@@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP

     model = GeminiDDP(model,
                       device=get_current_device(),
                       placement_policy=placememt_policy,
                       pin_memory=True,
-                      search_range_mb=32)
+                      search_range_mb=64)
     return model


 def main(args):
-    # config for colossalai
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})

     if args.seed is not None:
         gpc.set_seed(args.seed)

@@ -405,9 +385,9 @@ def main(args):
         pipeline.to(get_current_device())

         for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
                 disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
         ):
             images = pipeline(example["prompt"]).images

@@ -472,10 +452,11 @@ def main(args):
     )

     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
                                                     low_cpu_mem_usage=False)

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)

@@ -486,10 +467,10 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)

-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)

     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
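
The clipping threshold that previously travelled through the colossalai launch config ("clip_grad_norm" in the removed dict above) is now handed to the optimizer directly. A short before/after sketch, under the assumption that the Gemini optimizer applies the norm clipping internally during step():

    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

    # Before: the threshold lived in the launch config
    #   config = {"clip_grad_norm": args.max_grad_norm, ...}
    #   colossalai.launch_from_torch(config=config)
    # After: the threshold is passed to the optimizer wrapping the GeminiDDP model
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)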
@@ -520,7 +501,9 @@ def main(args):
             pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

             input_ids = tokenizer.pad(
                 {"input_ids": input_ids},
                 padding="max_length",
                 max_length=tokenizer.model_max_length,
                 return_tensors="pt",

@@ -532,9 +515,11 @@ def main(args):
         }
         return batch

     train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=args.train_batch_size,
                                                    shuffle=True,
                                                    collate_fn=collate_fn,
                                                    num_workers=1)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False

@@ -652,15 +637,16 @@ def main(args):
                 logs = {
                     "loss": loss.detach().item(),
                     "lr": optimizer.param_groups[0]["lr"],
                 }    # lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)

                 if global_step % args.save_steps == 0:
                     torch.cuda.synchronize()
+                    torch_unet = get_static_torch_model(unet)
                     if gpc.get_local_rank(ParallelMode.DATA) == 0:
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=convert_to_torch_module(unet),
+                            unet=torch_unet,
                             revision=args.revision,
                         )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")

@@ -670,7 +656,7 @@ def main(args):
             break

     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)

     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
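
In the checkpoint path, get_static_torch_model(unet) is now called on every rank before the rank-0 branch: the helper gathers chunks collectively and ends with dist.barrier(), so it must not be guarded by a rank check. A condensed sketch of the resulting save flow; pipeline.save_pretrained() is the standard diffusers call, but its use with save_path here is an assumption, since that line sits outside the shown hunks:

    import os

    import torch
    from diffusers import DiffusionPipeline

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.nn.parallel.utils import get_static_torch_model


    def save_checkpoint(args, unet, global_step):
        # Every rank must reach this call: it gathers chunk shards collectively,
        # so calling it only on rank 0 would leave the other ranks stuck.
        torch.cuda.synchronize()
        torch_unet = get_static_torch_model(unet)
        if gpc.get_local_rank(ParallelMode.DATA) == 0:
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=torch_unet,
                revision=args.revision,
            )
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            pipeline.save_pretrained(save_path)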

tests/test_gemini/update/test_convert_torch_module.py → tests/test_gemini/update/test_get_torch_model.py (view file @ 48d33b1b)
@@ -6,8 +6,9 @@ import torch
 import torch.multiprocessing as mp

 import colossalai
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.tensor import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device

@@ -15,21 +16,29 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs


-@parameterize('model_name', ['resnet18', 'bert'])
+@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
 def run_convert_torch_module(model_name: str):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, _, _, _, _ = get_components_func()

-    with ColoInitContext(device='cpu'):
+    with ColoInitContext(device=torch.device("cpu")):
         model = model_builder(checkpoint=False)

-    from colossalai.nn.parallel import GeminiDDP
     model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
-    pytorch_model = convert_to_torch_module(model)
+    pytorch_model = get_static_torch_model(model, only_rank_0=False)

     for n, p in pytorch_model.named_parameters():
-        assert not isinstance(p, ColoTensor)
+        assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"
+
+    # get the static model should not change the original model
+    for n, p in model.named_parameters():
+        assert isinstance(p, ColoParameter)
+
+    for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
+        assert pn == cn
+        assert id(pm) != id(cm)
+        for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
+            assert id(pp) != id(cp)
+            assert pp.shape == cp.shape


 def run_dist(rank, world_size, port):
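
The bodies of run_dist and the pytest entry point are collapsed in this view. A rough sketch of the usual harness around such a test, assuming the standard colossalai.launch call used elsewhere in the test suite (the hidden code itself is not shown in this commit):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp

    import colossalai
    from colossalai.testing import rerun_if_address_is_in_use
    from colossalai.utils import free_port


    def run_dist(rank, world_size, port):
        # Assumed body: initialise the process group, then run the parameterized check above.
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        run_convert_torch_module()


    @pytest.mark.dist
    @rerun_if_address_is_in_use()
    def test_get_torch_model(world_size=1):
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)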