OpenDAS / ColossalAI · Commits

Commit a2e61d61 (unverified), authored Mar 24, 2022 by ver217, committed via GitHub on Mar 24, 2022

[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)

* enable rm_torch_payload_on_the_fly
* polish docstr

Parent: 81145208

1 changed file with 17 additions and 13 deletions (+17 −13): colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py @ a2e61d61
@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
+    r"""
     A context to initialize model.
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
-    target_device: the device where param data lives after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is grad sharded after exiting the context
-    rm_torch_payload_on_the_fly:
-        True: remove tensor payload on param.data after module init finishes.
-        False: remove tensor payload on param.data after the context exits.
-            This is used when you add logic that operates on tensors in a module's __init__.
-            See torchvision resnet18.
+
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data lives after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Is grad sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finishes.
+            This reduces memory usage when initializing the model,
+            but it is not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            Use this when you add logic that operates on tensors in a module's `__init__`;
+            see torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """

     def __init__(self,
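For orientation, the base class name describes the mechanism: `InsertPostInitMethodToModuleSubClasses` temporarily patches the `__init__` of module subclasses so that a post-init hook (sharding, payload removal) runs right after each module is constructed. Below is a minimal pure-Python sketch of that pattern; all names are illustrative stand-ins, not ColossalAI's actual implementation.

```python
class PostInitPatcher:
    """Context manager that wraps __init__ of `base` and its subclasses
    so that post_init_hook(instance) runs right after construction.
    Sketch only: the real class also recurses into all descendants and
    handles torch.nn.Module specifics."""

    def __init__(self, base, post_init_hook):
        self.base = base
        self.post_init_hook = post_init_hook
        self._originals = {}

    def __enter__(self):
        for cls in [self.base] + self.base.__subclasses__():
            original = cls.__init__
            self._originals[cls] = original

            def make_wrapper(orig):
                # Factory avoids the late-binding closure pitfall.
                def wrapper(obj, *args, **kwargs):
                    orig(obj, *args, **kwargs)
                    self.post_init_hook(obj)  # runs after normal init
                return wrapper

            cls.__init__ = make_wrapper(original)
        return self

    def __exit__(self, *exc):
        # Restore the original __init__ methods on exit.
        for cls, original in self._originals.items():
            cls.__init__ = original
        return False


class Module:  # stand-in for torch.nn.Module
    def __init__(self):
        self.params = [1.0, 2.0, 3.0]


class Linear(Module):
    def __init__(self):
        super().__init__()


def shard_hook(module):
    # Stand-in for "shard params / remove payload on the fly".
    module.sharded = True


with PostInitPatcher(Module, shard_hook):
    m = Linear()  # hook fires automatically after construction
```

Inside the `with` block every constructed module passes through the hook; outside it, construction is untouched, which is exactly why the class is implemented as a context manager.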
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
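The flag this commit enables controls when parameter payloads are freed: on the fly (immediately after each submodule finishes `__init__`, keeping peak memory near a single module's size) or all at once when the context exits (needed when `__init__` still reads the tensors afterwards, as with torchvision resnet18's weight init). A toy model of that trade-off, with plain Python lists standing in for tensors; the names are hypothetical, not the ColossalAI API.

```python
class FakeParam:
    """Stand-in for a torch Parameter: `payload` mimics param.data."""
    def __init__(self, n):
        self.payload = [0.0] * n  # the memory we want to reclaim
        self.numel = n


def free_payload(param):
    param.payload = []  # mimic swapping data for an empty tensor


def init_model(params, rm_on_the_fly):
    """Simulate module-by-module init, tracking peak live elements."""
    peak = 0
    live = 0
    for p in params:
        live += p.numel
        peak = max(peak, live)
        if rm_on_the_fly:
            # Free right after this "module" finishes init.
            free_payload(p)
            live -= p.numel
    if not rm_on_the_fly:
        # Free everything only at context exit.
        for p in params:
            free_payload(p)
    return peak


peak_fly = init_model([FakeParam(100) for _ in range(4)], rm_on_the_fly=True)
peak_exit = init_model([FakeParam(100) for _ in range(4)], rm_on_the_fly=False)
```

With four 100-element "modules", freeing on the fly keeps the peak at one module's worth of memory, while deferring to context exit peaks at the whole model, which is why the docstring warns the eager mode is unsafe only when `__init__` code still touches the tensors.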