OpenDAS / ColossalAI · Commits · 80aed29c

Unverified commit 80aed29c, authored Mar 21, 2023 by YH, committed by GitHub on Mar 21, 2023
[zero] Refactor ZeroContextConfig class using dataclass (#3186)
parent 9d644ff0
Showing 1 changed file with 19 additions and 20 deletions.

colossalai/zero/init_ctx/init_context.py (+19 / -20)
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
 import contextlib
 import functools
-from typing import Optional
 from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Optional

 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn

 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses


-class ZeroContextConfig(object):
+@dataclass
+class ZeroContextConfig:
     """The configuration used to control zero context initialization.

     Args:
         target_device (torch.device): The device where param data are after exiting the context.
-        replicated (bool, optional): Whether the param is replicated across data parallel group.
+        is_replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
     """

-    def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
-        super().__init__()
+    target_device: torch.device
+    is_replicated: bool = True
+    shard_param: bool = False

-        if shard_param:
-            assert replicated, "Non-replicated parameters can't be sharded."
+    def __post_init__(self):
+        if self.shard_param:
+            assert self.is_replicated, "Non-replicated parameters can't be sharded."

         # replicated no-shard parameters should locate in cuda, since we will broadcast them soon
-        if replicated and not shard_param:
-            assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
-
-        self.target_device = target_device
-        self.is_replicated: bool = replicated
-        self.shard_param: bool = shard_param
+        if self.is_replicated and not self.shard_param:
+            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."


 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

@@ -74,7 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.seed = seed
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)

-        self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
+        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

         ZeroContextMgr().current_context = self

@@ -124,7 +123,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         return fan_in, fan_out

     def _pre_context_exec(self):
         """
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")

@@ -248,7 +247,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
     return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
-                                                  replicated=is_replicated,
+                                                  is_replicated=is_replicated,
                                                   shard_param=False)
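For readers who want the net effect of the refactor in one place, here is a minimal standalone sketch. The class body mirrors the diff above; the usage lines at the bottom are illustrative only and are not part of the commit.

```python
from dataclasses import dataclass

import torch


@dataclass
class ZeroContextConfig:
    """Sketch of the refactored config; the body mirrors the diff above."""
    target_device: torch.device
    is_replicated: bool = True
    shard_param: bool = False

    def __post_init__(self):
        # __post_init__ runs right after the generated __init__, so the
        # invariants the old hand-written __init__ enforced still hold.
        if self.shard_param:
            assert self.is_replicated, "Non-replicated parameters can't be sharded."
        # Replicated no-shard parameters must live on CUDA, since they are broadcast soon after.
        if self.is_replicated and not self.shard_param:
            assert self.target_device.type == 'cuda', \
                "Replicated no-shard parameters should be located in cuda."


# Callers now use the renamed keyword `is_replicated` (previously `replicated`):
cfg = ZeroContextConfig(target_device=torch.device('cuda', 0), is_replicated=True, shard_param=True)

# @dataclass also generates __repr__ and __eq__ for free:
print(cfg)
# ZeroContextConfig(target_device=device(type='cuda', index=0), is_replicated=True, shard_param=True)
```

Because `@dataclass` generates `__init__` and then calls `__post_init__`, validation behavior is unchanged; the only change visible to callers is the `replicated` to `is_replicated` keyword rename, as seen in the `ZeroInitContext` and `no_shard_zero_context` hunks above.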