OpenDAS / ColossalAI · Commit 036404ca (unverified)

Authored Apr 02, 2022 by Jiarui Fang, committed by GitHub on Apr 02, 2022
Revert "[zero] polish init context (#645)" (#657)
Parent: b31daed4

Showing 1 changed file with 13 additions and 25 deletions.

colossalai/zero/init_ctx/init_context.py (+13, -25), view file @ 036404ca
@@ -90,11 +90,9 @@ class ZeroContextConfig(object):
     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel (DP) group.
-            We do not need to synchronize (reduce) the grads of the replicated params among DP group.
+        replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-            The process group among which tensors are sharded is assigned as an runtime arg.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
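For orientation, the two situations the restored docstring describes can be written out as below. This is not part of the commit: ZeroContextConfig is an internal class, and the import path, keyword names, and defaults are assumptions read off the docstring and the constructor lines in this diff.

    import torch
    from colossalai.zero.init_ctx.init_context import ZeroContextConfig  # internal class, path assumed

    # Ordinary parameters: replicated across the data parallel group and sharded on context exit.
    plain_cfg = ZeroContextConfig(target_device=torch.device('cuda'),
                                  replicated=True,
                                  shard_param=True)

    # MOE expert parameters: not replicated across the data parallel group (per the docstring above).
    expert_cfg = ZeroContextConfig(target_device=torch.device('cuda'),
                                   replicated=False,
                                   shard_param=False)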
@@ -112,9 +110,6 @@ class ZeroContextConfig(object):
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
-        if self.is_replicated is False:
-            assert self.shard_param is True, f"ZeroContextConfig shard_param must be False when is_replicated is False"
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
@@ -122,8 +117,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.

     1. Convert the model to fp16.
-    2. The paramaters of the module are adapted to type `ShardedParameter`.
-    3. Shard the param and grad according to flag `shard_param`.
+    2. The paramaters of the module are adapted to type ShardedParameter.
+    3. Shard the param and grad according to flags.

     Args:
         target_device (torch.device): The device where param data are after exiting the context.
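For reference, typical use of ZeroInitContext around this version looks roughly like the sketch below. It is not part of this diff; TensorShardStrategy, the import paths, and the keyword names are assumptions based on the constructor arguments visible in the following hunks, and colossalai.launch is assumed to have initialized the global process groups first.

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Assumes colossalai.launch() has already set up the global (gpc) process groups.
    with ZeroInitContext(target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        # Parameters created inside the context are cast to fp16, adapted to
        # ShardedParameter, and sharded across the data parallel group.
        model = torch.nn.Linear(1024, 1024)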
@@ -149,8 +144,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.shard_strategy = shard_strategy
-        # a list contains params that could be sharded.
-        self.shardable_param_list = []
+        self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -187,17 +181,21 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
-            for param in self.shardable_param_list:
+            for param in self.initialized_param_list:
+                assert hasattr(param, 'colo_attr')
                 param.colo_attr.remove_torch_payload()

-        del self.shardable_param_list
+        del self.initialized_param_list

     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """

+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t

         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'colo_attr'):
@@ -209,10 +207,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             param.is_replicated = self.is_replicated

             # convert parameters to half
-            param_half = cast_tensor_to_fp16(param.data)
+            param_half = half_fn(param)
             param.data = param_half
             if param.grad is not None:
-                grad_half = cast_tensor_to_fp16(param.grad)
+                grad_half = half_fn(param.grad)
                 param.grad.data = grad_half

             # move torch parameters to the target device
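The restored half_fn casts only floating-point tensors to fp16; the removed lines routed the same conversion through cast_tensor_to_fp16 instead. A self-contained sketch of that guard, plain PyTorch and nothing ColossalAI-specific:

    import torch

    def half_fn(t: torch.Tensor):
        # Same guard as the restored helper: floating-point tensors are cast to
        # fp16, anything else (e.g. integer buffers) is returned unchanged.
        return t.half() if t.is_floating_point() else t

    weight = torch.randn(4, 4)                  # float32 -> float16
    counter = torch.zeros(1, dtype=torch.long)  # int64   -> left as is

    assert half_fn(weight).dtype == torch.float16
    assert half_fn(counter).dtype == torch.long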
@@ -225,7 +223,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            self.shardable_param_list.append(param)
+            self.initialized_param_list.append(param)

             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
@@ -257,16 +255,6 @@ def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
 def no_shard_zero_decrator(is_replicated: bool = True):
-    """
-    A decorator used to wrap an __init__ function of Module.
-    The parameters initialized by the model will not sharded.
-    is_replicated indicates the grad of the param won't be reduced among the data parallel process group.
-
-    >>> def MyModule(torch.nn.Module):
-    >>>     @no_shard_zero_decrator(is_replicated = False)
-    >>>     def __init__(self, ...)
-    >>>         ....
-    """

     def _wrapper(init_func):
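The docstring removed here documented the intended use of the decorator. Written out as a sketch, with the expert module below being an invented example rather than ColossalAI code, and the import path an assumption based on where the function is defined:

    import torch
    from colossalai.zero.init_ctx.init_context import no_shard_zero_decrator

    class MyExpert(torch.nn.Module):

        @no_shard_zero_decrator(is_replicated=False)
        def __init__(self, hidden_size: int):
            super().__init__()
            # Parameters created in this __init__ are left un-sharded, and with
            # is_replicated=False their grads are not reduced across the data parallel group.
            self.proj = torch.nn.Linear(hidden_size, hidden_size)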