OpenDAS / ColossalAI · Commits · 52c6ad26

Unverified commit 52c6ad26, authored Nov 15, 2022 by Jiarui Fang, committed by GitHub on Nov 15, 2022

[ColoTensor] reconfig ColoInitContext, decouple default_pg and default_dist_spec. (#1953)
Parent: 598d456d

Showing 2 changed files with 20 additions and 14 deletions (+20, -14)
colossalai/utils/model/colo_init_context.py    +14 -11
tests/test_tensor/test_context.py              +6 -3
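In practice, the change replaces ColoInitContext's single default_shard_plan dict with two independent arguments, default_pg and default_dist_spec. Below is a minimal before/after sketch of a caller under the same distributed setup the test file uses; world_size and model_builder are placeholders borrowed from that test, not part of this commit, and the context must run inside a launched ColossalAI process group as the test does.

from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

pg = ProcessGroup(tp_degree=world_size)                      # placeholder: world_size comes from the launcher
spec = ShardSpec(dims=[0], num_partitions=[world_size])

# Old API, removed by this commit: both defaults were bundled in one dict.
#   with ColoInitContext(device=get_current_device(),
#                        default_shard_plan={'pg': pg, 'shard_spec': spec}):
#       model = model_builder()

# New API: the process group and the distribution spec are passed separately,
# so either one can be supplied without the other.
with ColoInitContext(device=get_current_device(), default_pg=pg, default_dist_spec=spec):
    model = model_builder()                                  # placeholder model factory from the test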
colossalai/utils/model/colo_init_context.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
-from colossalai.tensor import ColoParameter, ColoTensor
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
 from .utils import InsertPostInitMethodToModuleSubClasses
@@ -39,18 +39,22 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
     def __init__(self,
                  device: torch.device = torch.device('cpu'),
                  dtype: torch.dtype = torch.float,
-                 default_shard_plan: Optional[Dict] = None):
+                 default_pg: Optional[ProcessGroup] = None,
+                 default_dist_spec=None):
         """
         Args:
             device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
             dtype (torch.dtype): the dtype of parameters initialized. Defults to torch.float.
+            default_pg (ProcessGroup): the default process group for all initialized parameters.
+            default_dist_spec: the default distributed specifications.
         """
         super().__init__()
         self._device = device
         self._dtype = dtype

         self._register_colo_modules()
-        self._default_shard_plan = default_shard_plan
+        self._default_pg = default_pg
+        self._default_dist_spec = default_dist_spec

     def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
@@ -68,10 +72,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         if hasattr(module, '_colo_visited'):
             return

-        if self._default_shard_plan is not None:
-            default_pg = self._default_shard_plan.get('pg', None)
-            default_shard_spec = self._default_shard_plan.get('shard_spec', None)
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -96,7 +96,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             else:
                 # detaching tensor is necessary for optimizers.
                 requires_grad = param.requires_grad
+                # TODO(jiaruifang) we initialize a Default PG memory
                 # param is the global tensor.
                 colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype), requires_grad=requires_grad)
@@ -104,10 +105,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 # This can reduce the model size after initialization.
                 # NOTE() embedding usually can not be correctly sharded. So I use except to handle
                 # the param that can not be sharded by the default plan
-                if self._default_shard_plan is not None:
-                    colo_param.set_process_group(default_pg)
+                if self._default_pg is not None:
+                    colo_param.set_process_group(self._default_pg)
+
+                if self._default_dist_spec is not None:
                     try:
-                        colo_param.set_dist_spec(default_shard_spec)
+                        colo_param.set_dist_spec(self._default_dist_spec)
                     except:
                         pass
tests/test_tensor/test_context.py
@@ -37,9 +37,12 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
     # shard the parameters during init
     set_seed(42)
     shard_spec = ReplicaSpec()    # ShardSpec(dims=[0], num_partitions=[world_size])
-    default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
-    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
+    # If using ShardSpec, the assertations will failed.
+    # But it is not a bug, the initialized values are not consist with the original one.
+    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
+    default_pg = ProcessGroup(tp_degree=world_size)
+    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
         model2 = model_builder()

     # reshard both models
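Because the two defaults are now decoupled, a caller can also supply only a process group and no distribution spec; in the updated _post_init_method above, set_dist_spec is simply skipped when default_dist_spec is None, so no resharding is attempted during initialization. A hedged sketch of that case, again assuming the test's launched distributed context, with world_size and model_builder as placeholders:

from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Only a default process group is given; with no default_dist_spec,
# parameters are assigned to the group but are not resharded at init.
with ColoInitContext(device=get_current_device(),
                     default_pg=ProcessGroup(tp_degree=world_size)):
    model = model_builder()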