OpenDAS / ColossalAI · Commits

Commit 67b49282 (Unverified)
[zero] polish init context (#645)
Authored Apr 02, 2022 by Jiarui Fang; committed by GitHub on Apr 02, 2022
Parent: f5d3a9c2
Showing 1 changed file with 25 additions and 13 deletions (+25 / -13)
colossalai/zero/init_ctx/init_context.py (+25 / -13)
@@ -90,9 +90,11 @@ class ZeroContextConfig(object):
    Args:
        target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
+            We do not need to synchronize (reduce) the grads of the replicated params among the DP group.
+            Some parameters are not replicated, e.g. parameters in MOE experts.
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+            The process group among which tensors are sharded is assigned as a runtime arg.
        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
            This will reduce memory usage when initializing model.
            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -110,6 +112,9 @@ class ZeroContextConfig(object):
        self.target_device = target_device
        self.is_replicated: bool = replicated
        self.shard_param: bool = shard_param
+
+        if self.is_replicated is False:
+            assert self.shard_param is True, "ZeroContextConfig shard_param must be True when is_replicated is False"

        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
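A minimal sketch of the invariant the new check enforces, assuming keyword arguments that match the Args section above; the construction below is illustrative and not taken from this commit:

# Hypothetical usage of ZeroContextConfig; argument names follow the docstring above.
import torch

from colossalai.zero.init_ctx.init_context import ZeroContextConfig  # assumed import path

# Replicated params may stay unsharded.
ok = ZeroContextConfig(target_device=torch.device('cuda:0'), replicated=True, shard_param=False)

# Non-replicated params (e.g. MoE experts) must be sharded, otherwise the new assertion fires.
bad = ZeroContextConfig(target_device=torch.device('cuda:0'), replicated=False, shard_param=False)  # AssertionError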
@@ -117,8 +122,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """A context to initialize model.
    1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type ShardedParameter.
-    3. Shard the param and grad according to flags.
+    2. The parameters of the module are adapted to type `ShardedParameter`.
+    3. Shard the param and grad according to flag `shard_param`.

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
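For orientation, a usage sketch of the context this docstring describes, assuming the constructor arguments visible elsewhere in this diff (target_device, shard_strategy, shard_param); the import paths and the TensorShardStrategy class are assumptions, not part of this commit:

# Hypothetical usage: parameters created inside the context are cast to fp16,
# adapted to ShardedParameter, and sharded across the DP group when shard_param=True.
import torch
from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path

with ZeroInitContext(target_device=torch.device('cuda', 0),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = torch.nn.Linear(1024, 1024)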
@@ -144,7 +149,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        super().__init__()
        self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        # a list that contains params that could be sharded
+        self.shardable_param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -181,21 +187,17 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        """The callback function when exiting context.
        """
        if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
+            for param in self.shardable_param_list:
                assert hasattr(param, 'colo_attr')
                param.colo_attr.remove_torch_payload()

-            del self.initialized_param_list
+            del self.shardable_param_list

    def _post_init_method(self, module: torch.nn.Module):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.
        """
-        def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'colo_attr'):
@@ -207,10 +209,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            param.is_replicated = self.is_replicated

            # convert parameters to half
-            param_half = half_fn(param)
+            param_half = cast_tensor_to_fp16(param.data)
            param.data = param_half
            if param.grad is not None:
-                grad_half = half_fn(param.grad)
+                grad_half = cast_tensor_to_fp16(param.grad)
                param.grad.data = grad_half

            # move torch parameters to the target device
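For reference, the local helper removed earlier in this diff behaved as follows; the shared cast_tensor_to_fp16 helper that replaces it is assumed to perform the equivalent cast for floating-point tensors:

import torch

def half_fn(t: torch.Tensor) -> torch.Tensor:
    # Cast floating-point tensors to fp16, leave non-floating tensors untouched.
    return t.half() if t.is_floating_point() else t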
@@ -223,7 +225,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)
+            self.shardable_param_list.append(param)

            # We must cast buffers
            # If we use BN, buffers may be on CPU and Float
@@ -255,6 +257,16 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
def no_shard_zero_decrator(is_replicated: bool = True):
    """
    A decorator used to wrap an __init__ function of a Module.
    The parameters initialized by the model will not be sharded.
    is_replicated indicates that the grad of the param won't be reduced among the data parallel process group.

    >>> class MyModule(torch.nn.Module):
    >>>     @no_shard_zero_decrator(is_replicated=False)
    >>>     def __init__(self, ...):
    >>>         ....
    """

    def _wrapper(init_func):
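The body of _wrapper falls outside the lines shown above. A minimal sketch of what such a decorator could look like, assuming it simply runs the wrapped __init__ inside no_shard_zero_context; this is an illustration, not the commit's implementation:

# Hypothetical completion of no_shard_zero_decrator, NOT taken from this commit.
from functools import wraps

from colossalai.zero.init_ctx.init_context import no_shard_zero_context  # assumed import path

def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        @wraps(init_func)
        def _no_shard_init(self, *args, **kwargs):
            # Run the module's __init__ under a no-shard ZeRO context.
            with no_shard_zero_context(is_replicated):
                init_func(self, *args, **kwargs)

        return _no_shard_init

    return _wrapper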