Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
67b49282
Unverified
Commit
67b49282
authored
Apr 02, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 02, 2022
Browse files
[zero] polish init context (#645)
parent
f5d3a9c2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
13 deletions
+25
-13
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+25
-13
No files found.
colossalai/zero/init_ctx/init_context.py
View file @
67b49282
...
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
...
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
Args:
Args:
target_device (torch.device): The device where param data are after exiting the context.
target_device (torch.device): The device where param data are after exiting the context.
replicated (bool, optional): Whether the param is replicated across data parallel group.
replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
We do not need to synchronize (reduce) the grads of the replicated params among DP group.
Some parameters are not replicated, e.g. parameters in MOE experts.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
The process group among which tensors are sharded is assigned as a runtime arg.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
This will reduce memory usage when initializing model.
This will reduce memory usage when initializing model.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
...
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
...
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
self
.
target_device
=
target_device
self
.
target_device
=
target_device
self
.
is_replicated
:
bool
=
replicated
self
.
is_replicated
:
bool
=
replicated
self
.
shard_param
:
bool
=
shard_param
self
.
shard_param
:
bool
=
shard_param
if
self
.
is_replicated
is
False
:
assert
self
.
shard_param
is
True
,
f
"ZeroContextConfig shard_param must be False when is_replicated is False"
self
.
rm_torch_payload_on_the_fly
:
bool
=
rm_torch_payload_on_the_fly
self
.
rm_torch_payload_on_the_fly
:
bool
=
rm_torch_payload_on_the_fly
...
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
"""A context to initialize model.
1. Convert the model to fp16.
1. Convert the model to fp16.
2. The parameters of the module are adapted to type ShardedParameter.
2. The parameters of the module are adapted to type
`
ShardedParameter
`
.
3. Shard the param and grad according to flag
s
.
3. Shard the param and grad according to flag
`shard_param`
.
Args:
Args:
target_device (torch.device): The device where param data are after exiting the context.
target_device (torch.device): The device where param data are after exiting the context.
...
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
super
().
__init__
()
super
().
__init__
()
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
self
.
initialized_param_list
=
[]
# a list containing the params that could be sharded.
self
.
shardable_param_list
=
[]
self
.
model_numel_tensor
=
model_numel_tensor
self
.
model_numel_tensor
=
model_numel_tensor
self
.
dp_process_group
=
dp_process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
self
.
dp_process_group
=
dp_process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
...
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""The callback function when exiting context.
"""The callback function when exiting context.
"""
"""
if
not
self
.
rm_torch_payload_on_the_fly
:
if
not
self
.
rm_torch_payload_on_the_fly
:
for
param
in
self
.
initialized
_param_list
:
for
param
in
self
.
shardable
_param_list
:
assert
hasattr
(
param
,
'colo_attr'
)
assert
hasattr
(
param
,
'colo_attr'
)
param
.
colo_attr
.
remove_torch_payload
()
param
.
colo_attr
.
remove_torch_payload
()
del
self
.
initialized
_param_list
del
self
.
shardable
_param_list
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
"""
"""
The function to call at the end of the constructor of each module.
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
NOTE() The module may be passed to this function multiple times.
"""
"""
def
half_fn
(
t
:
torch
.
Tensor
):
return
t
.
half
()
if
t
.
is_floating_point
()
else
t
for
param
in
module
.
parameters
(
recurse
=
False
):
for
param
in
module
.
parameters
(
recurse
=
False
):
# avoid adapting a param to ShardedParam twice
# avoid adapting a param to ShardedParam twice
if
hasattr
(
param
,
'colo_attr'
):
if
hasattr
(
param
,
'colo_attr'
):
...
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param
.
is_replicated
=
self
.
is_replicated
param
.
is_replicated
=
self
.
is_replicated
# convert parameters to half
# convert parameters to half
param_half
=
half_fn
(
param
)
param_half
=
cast_tensor_to_fp16
(
param
.
data
)
param
.
data
=
param_half
param
.
data
=
param_half
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
grad_half
=
half_fn
(
param
.
grad
)
grad_half
=
cast_tensor_to_fp16
(
param
.
grad
)
param
.
grad
.
data
=
grad_half
param
.
grad
.
data
=
grad_half
# move torch parameters to the target device
# move torch parameters to the target device
...
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
colo_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
shard_strategy
.
shard
([
param
.
colo_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
initialized
_param_list
.
append
(
param
)
self
.
shardable
_param_list
.
append
(
param
)
# We must cast buffers
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# If we use BN, buffers may be on CPU and Float
...
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
...
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
def
no_shard_zero_decrator
(
is_replicated
:
bool
=
True
):
def
no_shard_zero_decrator
(
is_replicated
:
bool
=
True
):
"""
A decorator used to wrap an __init__ function of Module.
The parameters initialized by the model will not be sharded.
is_replicated indicates the grad of the param won't be reduced among the data parallel process group.
>>> class MyModule(torch.nn.Module):
>>> @no_shard_zero_decrator(is_replicated = False)
>>> def __init__(self, ...)
>>> ....
"""
def
_wrapper
(
init_func
):
def
_wrapper
(
init_func
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment