Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d7ecaf36
Unverified
Commit
d7ecaf36
authored
Apr 07, 2022
by
HELSON
Committed by
GitHub
Apr 07, 2022
Browse files
[zero] fix init bugs in zero context (#686)
* adapt model weight initialization for methods in Pytorch nn.init
parent
0ed7042f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
117 additions
and
86 deletions
+117
-86
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+84
-41
tests/test_moe/test_moe_zero_init.py
tests/test_moe/test_moe_zero_init.py
+24
-25
tests/test_moe/test_moe_zero_model.py
tests/test_moe/test_moe_zero_model.py
+2
-5
tests/test_moe/test_moe_zero_optim.py
tests/test_moe/test_moe_zero_optim.py
+1
-2
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+1
-3
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+2
-5
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+1
-2
tests/test_zero_data_parallel/test_state_dict.py
tests/test_zero_data_parallel/test_state_dict.py
+2
-3
No files found.
colossalai/zero/init_ctx/init_context.py
View file @
d7ecaf36
...
...
@@ -3,6 +3,8 @@ import functools
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context.singleton_meta
import
SingletonMeta
...
...
@@ -10,7 +12,6 @@ from colossalai.logging import get_dist_logger
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
contextlib
import
AbstractContextManager
...
...
@@ -93,24 +94,21 @@ class ZeroContextConfig(object):
replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
If set to `False`, remove tensor payload on param.data after the context exits.
This is used when you add some logic to operate tensors in __init__ of module.
See torchvision resnet18. Defaults to False.
"""
def
__init__
(
self
,
target_device
:
torch
.
device
,
replicated
:
bool
=
True
,
shard_param
:
bool
=
False
,
rm_torch_payload_on_the_fly
:
bool
=
False
):
def
__init__
(
self
,
target_device
:
torch
.
device
,
replicated
:
bool
=
True
,
shard_param
:
bool
=
False
):
super
().
__init__
()
if
shard_param
:
assert
replicated
,
"Non-replicated parameters can't be sharded."
# replicated no-shard parameters should be located on cuda, since we will broadcast them soon
if
replicated
and
not
shard_param
:
assert
target_device
.
type
==
'cuda'
,
"Replicated no-shard paramters should locate in cuda."
self
.
target_device
=
target_device
self
.
is_replicated
:
bool
=
replicated
self
.
shard_param
:
bool
=
shard_param
self
.
rm_torch_payload_on_the_fly
:
bool
=
rm_torch_payload_on_the_fly
class
ZeroInitContext
(
InsertPostInitMethodToModuleSubClasses
):
...
...
@@ -123,35 +121,27 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
Args:
target_device (torch.device): The device where param data are after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
If set to `False`, remove tensor payload on param.data after the context exits.
This is used when you add some logic to operate tensors in __init__ of module.
See torchvision resnet18. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
"""
def
__init__
(
self
,
target_device
:
torch
.
device
,
shard_strategy
:
BaseShardStrategy
,
seed
:
int
=
2
**
10
-
1
,
shard_param
:
bool
=
False
,
rm_torch_payload_on_the_fly
:
bool
=
False
,
model_numel_tensor
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
long
),
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
):
model_numel_tensor
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
long
)):
super
().
__init__
()
self
.
shard_strategy
=
shard_strategy
self
.
initialized_param_list
=
[]
self
.
sharded_param_list
=
[]
self
.
unshard_param_list
=
[]
self
.
model_numel_tensor
=
model_numel_tensor
self
.
dp_process_group
=
dp_process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
self
.
seed
=
seed
self
.
dp_process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
)
self
.
config
=
ZeroContextConfig
(
target_device
=
target_device
,
replicated
=
True
,
shard_param
=
shard_param
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
)
self
.
config
=
ZeroContextConfig
(
target_device
=
target_device
,
replicated
=
True
,
shard_param
=
shard_param
)
ZeroContextMgr
().
current_context
=
self
...
...
@@ -167,9 +157,35 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
def
shard_param
(
self
):
return
self
.
config
.
shard_param
@
property
def
rm_torch_payload_on_the_fly
(
self
):
return
self
.
config
.
rm_torch_payload_on_the_fly
@
staticmethod
def
calc_fanin_fanout
(
tensor
:
torch
.
Tensor
):
"""We use this function to substitute fan-in and fan-out calculation in torch.nn.init.
This can help us get correct fan-in and fan-out for sharded tensor.
"""
assert
isinstance
(
tensor
,
nn
.
Parameter
),
"Sharded tensor initilization is only allowed for paramters"
# get correct shape of input tensor
if
not
hasattr
(
tensor
,
'colo_attr'
)
or
not
tensor
.
colo_attr
.
param_is_sharded
:
tensor_shape
=
tensor
.
shape
else
:
tensor_shape
=
tensor
.
colo_attr
.
sharded_data_tensor
.
origin_shape
dimensions
=
len
(
tensor_shape
)
if
dimensions
<
2
:
raise
ValueError
(
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
)
num_input_fmaps
=
tensor_shape
[
1
]
num_output_fmaps
=
tensor_shape
[
0
]
receptive_field_size
=
1
if
dimensions
>
2
:
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for
s
in
tensor_shape
[
2
:]:
receptive_field_size
*=
s
fan_in
=
num_input_fmaps
*
receptive_field_size
fan_out
=
num_output_fmaps
*
receptive_field_size
return
fan_in
,
fan_out
def
_pre_context_exec
(
self
):
"""
...
...
@@ -177,15 +193,40 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
# substitute fan-in and fan-out calculation
self
.
nn_fanin_fanout
=
nn
.
init
.
_calculate_fan_in_and_fan_out
nn
.
init
.
_calculate_fan_in_and_fan_out
=
self
.
calc_fanin_fanout
# reserve rng states
self
.
cpu_rng_state
=
torch
.
get_rng_state
()
self
.
cuda_rng_state
=
torch
.
cuda
.
get_rng_state
()
# set a new seed for initialization, since we initialize sharded tensors separately
# we don't want all processes to have the same seed
# otherwise all sharded tensors would be the same after init
offset
=
self
.
seed
+
1
# we want the seed to have more 1 bits in its binary representation
torch
.
manual_seed
(
self
.
seed
+
offset
*
dist
.
get_rank
())
def
_post_context_exec
(
self
):
"""The callback function when exiting context.
"""
if
not
self
.
rm_torch_payload_on_the_fly
:
for
param
in
self
.
initialized_param_list
:
assert
hasattr
(
param
,
'colo_attr'
)
param
.
colo_attr
.
remove_torch_payload
()
for
param
in
self
.
sharded_param_list
:
assert
hasattr
(
param
,
'colo_attr'
)
param
.
colo_attr
.
remove_torch_payload
()
del
self
.
sharded_param_list
src_rank
=
gpc
.
get_ranks_in_group
(
ParallelMode
.
DATA
)[
0
]
for
param
in
self
.
unshard_param_list
:
assert
hasattr
(
param
,
'colo_attr'
)
if
param
.
is_replicated
:
dist
.
broadcast
(
tensor
=
param
.
data
,
src
=
src_rank
,
group
=
self
.
dp_process_group
)
del
self
.
unshard_param_list
del
self
.
initialized_param_list
nn
.
init
.
_calculate_fan_in_and_fan_out
=
self
.
nn_fanin_fanout
torch
.
set_rng_state
(
self
.
cpu_rng_state
)
torch
.
cuda
.
set_rng_state
(
self
.
cuda_rng_state
)
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
"""
...
...
@@ -219,11 +260,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
param
.
grad
is
not
None
:
param
.
grad
=
param
.
grad
.
to
(
target_device
)
param
.
colo_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
param
.
colo_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
False
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
colo_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
initialized_param_list
.
append
(
param
)
param
.
data
=
param
.
colo_attr
.
sharded_data_tensor
.
payload
self
.
sharded_param_list
.
append
(
param
)
else
:
self
.
unshard_param_list
.
append
(
param
)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
...
...
@@ -250,8 +294,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
def
no_shard_zero_context
(
is_replicated
:
bool
=
True
)
->
AbstractContextManager
:
return
ZeroContextMgr
().
hijack_context_config
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()),
replicated
=
is_replicated
,
shard_param
=
False
,
rm_torch_payload_on_the_fly
=
False
)
shard_param
=
False
)
def
no_shard_zero_decrator
(
is_replicated
:
bool
=
True
):
...
...
tests/test_moe/test_moe_zero_init.py
View file @
d7ecaf36
...
...
@@ -51,36 +51,36 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
with
ZeroInitContext
(
target_device
=
init_device
,
shard_strategy
=
shard_strategy_class
(),
shard_param
=
True
,
model_numel_tensor
=
model_numel_tensor
,
rm_torch_payload_on_the_fly
=
False
):
model_numel_tensor
=
model_numel_tensor
):
model
=
MoeModel
()
for
name
,
param
in
model
.
named_parameters
():
assert
hasattr
(
param
,
'colo_attr'
)
for
name
,
param
in
model
.
named_parameters
():
assert
hasattr
(
param
,
'colo_attr'
)
# the weights in the gate should be fp32
if
'gate'
in
name
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
dtype
==
torch
.
float32
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
dtype
==
torch
.
half
# the weights in the gate should be fp32
if
'gate'
in
name
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
dtype
==
torch
.
float32
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
dtype
==
torch
.
half
# the parameters in moe experts and its gate should not be sharded
if
(
'experts'
in
name
)
or
(
'gate'
in
name
)
or
(
'residual_combine'
in
name
):
assert
not
param
.
colo_attr
.
sharded_data_tensor
.
is_sharded
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
is_sharded
# the parameters in moe experts and its gate should not be sharded
if
(
'experts'
in
name
)
or
(
'gate'
in
name
)
or
(
'residual_combine'
in
name
):
assert
not
param
.
colo_attr
.
sharded_data_tensor
.
is_sharded
assert
param
.
colo_attr
.
sharded_data_tensor
.
data_ptr
()
==
param
.
data
.
data_ptr
()
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
is_sharded
# the parameters in moe experts are not replicated
if
'experts'
in
name
:
assert
not
param
.
is_replicated
else
:
assert
param
.
is_replicated
# the parameters in moe experts are not replicated
if
'experts'
in
name
:
assert
not
param
.
is_replicated
else
:
assert
param
.
is_replicated
if
param
.
colo_attr
.
param_is_sharded
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
'cuda'
if
param
.
colo_attr
.
param_is_sharded
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
else
:
assert
param
.
colo_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
'cuda'
def
_run_dist
(
rank
,
world_size
,
port
):
...
...
@@ -91,7 +91,6 @@ def _run_dist(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
])
@
pytest
.
mark
.
skip
(
"Under development"
)
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_moe_zero_init
(
world_size
):
run_func
=
partial
(
_run_dist
,
world_size
=
world_size
,
port
=
free_port
())
...
...
tests/test_moe/test_moe_zero_model.py
View file @
d7ecaf36
...
...
@@ -28,12 +28,9 @@ def run_model_test(enable_autocast, shard_strategy_class):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'no_leaf_module'
)
_
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
rm_torch_payload_on_the_fly
=
False
with
ZeroInitContext
(
target_device
=
torch
.
cuda
.
current_device
(),
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
shard_param
=
True
):
zero_model
=
MoeModel
()
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
use_memory_tracer
=
True
)
...
...
tests/test_moe/test_moe_zero_optim.py
View file @
d7ecaf36
...
...
@@ -60,8 +60,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cpu'
)
if
cpu_offload
else
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
shard_param
=
True
):
zero_model
=
MoeModel
()
zero_model
=
ShardedModelV2
(
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
d7ecaf36
...
...
@@ -28,7 +28,6 @@ def run_model_test(init_device_type, shard_strategy_class):
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
if
init_device_type
==
'cuda'
:
init_device
=
torch
.
device
(
f
"cuda:
{
get_current_device
()
}
"
)
elif
init_device_type
==
'cpu'
:
...
...
@@ -40,8 +39,7 @@ def run_model_test(init_device_type, shard_strategy_class):
with
ZeroInitContext
(
target_device
=
init_device
,
shard_strategy
=
shard_strategy_class
(),
shard_param
=
True
,
model_numel_tensor
=
model_numel_tensor
,
rm_torch_payload_on_the_fly
=
False
):
model_numel_tensor
=
model_numel_tensor
):
model
=
model_builder
(
checkpoint
=
True
)
for
param
in
model
.
parameters
():
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
d7ecaf36
...
...
@@ -29,12 +29,9 @@ def run_model_test(enable_autocast, shard_strategy_class):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
rm_torch_payload_on_the_fly
=
False
with
ZeroInitContext
(
target_device
=
torch
.
cuda
.
current_device
(),
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
shard_param
=
True
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
use_memory_tracer
=
True
)
...
...
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
d7ecaf36
...
...
@@ -60,8 +60,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
with
ZeroInitContext
(
target_device
=
torch
.
device
(
f
'cpu:0'
)
if
cpu_offload
else
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
shard_param
=
True
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
...
...
tests/test_zero_data_parallel/test_state_dict.py
View file @
d7ecaf36
...
...
@@ -27,10 +27,9 @@ def run_zero_state_dict(shard_strategy_class):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
with
ZeroInitContext
(
target_device
=
torch
.
cuda
.
current_device
(),
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()
)
,
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
shard_param
=
True
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment