OpenDAS / ColossalAI · Commits

Commit a2e61d61 (unverified), authored Mar 24, 2022 by ver217, committed by GitHub Mar 24, 2022

[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)

* enable rm_torch_payload_on_the_fly
* polish docstr
Parent: 81145208

Showing 1 changed file with 17 additions and 13 deletions:

colossalai/zero/init_ctx/init_context.py (+17 −13)
colossalai/zero/init_ctx/init_context.py (view file @ a2e61d61)

```diff
@@ -83,22 +83,26 @@ class InsertPostInitMethodToModuleSubClasses(object):
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
-    """
-    A context to initialize model.
+    r"""
+    A context to initialize model.
     1. Convert the model to fp16.
     2. The paramaters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
-    target_device: the device where param data after exiting the context
-    shard_strategy: shard strategy instance
-    shard_param: is param sharded after exiting the context
-    shard_grad: is param sharded after exiting the context
-    rm_torch_payload_on_the_fly:
-    True: remove tensor payload on param.data after module init finished.
-    False: remove tensor payload on param.data afther the context exist.
-    This is used when you add some logic to operate tensors in __init__ of module.
-    See torchvision resnet18.
+
+    Args:
+        convert_fp16 (bool): Whether to convert params to fp16.
+        target_device (torch.device): The device where param data after exiting the context.
+        shard_strategy (BaseShardStrategy): Shard strategy instance.
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on param.data afther the context exist.
+            This is used when you add some logic to operate tensors in __init__ of module.
+            See torchvision resnet18. Defaults to False.
+        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """

     def __init__(self,
...
@@ -110,14 +114,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  rm_torch_payload_on_the_fly: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        # FIXME(jiaruifang) now setting it to True is invalid.
-        self.rm_torch_payload_on_the_fly = False
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
```
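The mechanism this commit unblocks, running a hook right after each module's `__init__` and choosing whether to free parameter payloads immediately ("on the fly") or only when the context exits, can be illustrated with a framework-free sketch. This is not the ColossalAI API: `Module`, `PostInitContext`, and `payload` are illustrative stand-ins for `torch.nn.Module`, `ZeroInitContext`, and `param.data`, and the real class patches `__init__` on all `Module` subclasses rather than one class.

```python
class Module:
    """Illustrative stand-in for torch.nn.Module."""

    def __init__(self):
        self.payload = [0.0] * 4  # stand-in for param.data


class PostInitContext:
    """Wrap Module.__init__ so a post-init hook runs for every module
    constructed inside the context (sketch of the pattern behind
    InsertPostInitMethodToModuleSubClasses)."""

    def __init__(self, rm_payload_on_the_fly=False):
        self.rm_payload_on_the_fly = rm_payload_on_the_fly
        self.initialized = []

    def _post_init(self, module):
        self.initialized.append(module)
        if self.rm_payload_on_the_fly:
            # free right after this module's __init__ returns
            module.payload = None

    def __enter__(self):
        ctx = self
        self._orig_init = Module.__init__

        def wrapped_init(module, *args, **kwargs):
            ctx._orig_init(module, *args, **kwargs)
            ctx._post_init(module)

        Module.__init__ = wrapped_init
        return self

    def __exit__(self, *exc):
        Module.__init__ = self._orig_init
        if not self.rm_payload_on_the_fly:
            # defer freeing until the context exits, so __init__-time logic
            # (e.g. weight init, as in torchvision resnet18) still sees data
            for m in self.initialized:
                m.payload = None
        return False


with PostInitContext(rm_payload_on_the_fly=True):
    m = Module()
    assert m.payload is None  # freed on the fly, lower peak memory

with PostInitContext(rm_payload_on_the_fly=False):
    n = Module()
    assert n.payload is not None  # still alive for __init__-time logic
assert n.payload is None  # freed only at context exit
```

This is why the docstring warns that `rm_torch_payload_on_the_fly=True` is unsuitable for models whose `__init__` reads or writes parameter data: with the on-the-fly setting, the payload is already gone by the time any later `__init__` logic runs.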