Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
dd92b90a
Unverified
Commit
dd92b90a
authored
Apr 19, 2022
by
ver217
Committed by
GitHub
Apr 19, 2022
Browse files
[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext (#808)
* init fp16 param directly * polish code
parent
227d1cd4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+11
-4
No files found.
colossalai/zero/init_ctx/init_context.py
View file @
dd92b90a
...
@@ -23,14 +23,17 @@ def _substitute_init_recursively(cls, func):
...
@@ -23,14 +23,17 @@ def _substitute_init_recursively(cls, func):
class
InsertPostInitMethodToModuleSubClasses
(
object
):
class
InsertPostInitMethodToModuleSubClasses
(
object
):
def
__init__
(
self
):
def
__init__
(
self
,
default_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
pass
self
.
_old_default_dtype
=
None
self
.
_default_dtype
=
default_dtype
def
__enter__
(
self
):
def
__enter__
(
self
):
r
"""
r
"""
Enter the context scope.
Enter the context scope.
"""
"""
if
self
.
_default_dtype
is
not
None
:
self
.
_old_default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
self
.
_default_dtype
)
def
preprocess_after
(
f
):
def
preprocess_after
(
f
):
@
functools
.
wraps
(
f
)
@
functools
.
wraps
(
f
)
...
@@ -61,6 +64,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
...
@@ -61,6 +64,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
if
self
.
_default_dtype
is
not
None
:
torch
.
set_default_dtype
(
self
.
_old_default_dtype
)
def
_disable_class
(
cls
):
def
_disable_class
(
cls
):
cls
.
__init__
=
cls
.
_old_init
cls
.
__init__
=
cls
.
_old_init
...
@@ -123,6 +128,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -123,6 +128,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
shard_strategy (BaseShardStrategy): Shard strategy instance.
shard_strategy (BaseShardStrategy): Shard strategy instance.
seed (int, optional): Random seed for weight initialization
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
"""
...
@@ -131,9 +137,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -131,9 +137,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
shard_strategy
:
BaseShardStrategy
,
shard_strategy
:
BaseShardStrategy
,
seed
:
int
=
2
**
10
-
1
,
seed
:
int
=
2
**
10
-
1
,
shard_param
:
bool
=
False
,
shard_param
:
bool
=
False
,
default_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
model_numel_tensor
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
long
)):
model_numel_tensor
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
long
)):
super
().
__init__
()
super
().
__init__
(
default_dtype
=
default_dtype
)
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
self
.
param_list
=
[]
self
.
param_list
=
[]
self
.
model_numel_tensor
=
model_numel_tensor
self
.
model_numel_tensor
=
model_numel_tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment