OpenDAS / ColossalAI · Commit 9f4fb3f2 (unverified)
Authored Nov 14, 2022 by Jiarui Fang; committed via GitHub on Nov 14, 2022
[ColoTensor] ColoInitContext initialize parameters in shard mode. (#1937)
Parent: b42b6728

Showing 4 changed files, with 84 additions and 5 deletions (+84 −5):
  colossalai/utils/model/colo_init_context.py   +22 −3
  tests/test_tensor/test_context.py             +61 −0
  tests/test_tensor/test_sharded_linear.py      +0 −1
  tests/test_tensor/test_tp_with_zero.py        +1 −1
colossalai/utils/model/colo_init_context.py
-from typing import Iterator, Tuple, Union
+from typing import Dict, Iterator, Optional, Tuple, Union

 import torch
 from torch import nn
 ...
@@ -36,7 +36,10 @@ def ColoModulize(module):
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

-    def __init__(self, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float):
+    def __init__(self,
+                 device: torch.device = torch.device('cpu'),
+                 dtype: torch.dtype = torch.float,
+                 default_shard_plan: Optional[Dict] = None):
         """
         Args:
             device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
 ...
@@ -47,6 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._dtype = dtype
         self._register_colo_modules()
+        self._default_shard_plan = default_shard_plan

     def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
 ...
@@ -64,6 +68,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         if hasattr(module, '_colo_visited'):
             return

+        if self._default_shard_plan is not None:
+            default_pg = self._default_shard_plan.get('pg', None)
+            default_shard_spec = self._default_shard_plan.get('shard_spec', None)
+
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
 ...
@@ -91,7 +99,18 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 # TODO(jiaruifang) we initialize a Default PG memory
                 colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                            requires_grad=requires_grad)
                 # add mapping record
+                # if default_shard_plan exists, shard the param during initialization.
+                # This can reduce the model size after initialization.
+                # NOTE() embedding usually can not be correctly sharded. So I use except to handle
+                # the param that can not be sharded by the default plan
+                if self._default_shard_plan is not None:
+                    colo_param.set_process_group(default_pg)
+                    try:
+                        colo_param.set_dist_spec(default_shard_spec)
+                    except:
+                        pass
+
                 replaced_tensors[param] = colo_param
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
 ...
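
For orientation, here is a minimal usage sketch of the new default_shard_plan argument. It only uses names visible in this commit (ColoInitContext, ProcessGroup, ShardSpec, get_current_device); build_model is a hypothetical stand-in for any nn.Module constructor, and the process is assumed to have been launched with colossalai.launch, as in the new test below.

    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    world_size = 4    # assumed tensor-parallel degree, for illustration only

    # A shard plan is a plain dict holding a process group and a distribution spec.
    # Parameters that the spec cannot shard (e.g. embeddings) are left unsharded
    # by the try/except added above.
    default_shard_plan = {
        'pg': ProcessGroup(tp_degree=world_size),
        'shard_spec': ShardSpec(dims=[-1], num_partitions=[world_size]),
    }

    # Parameters are created directly in sharded form, reducing per-rank memory at init.
    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
        model = build_model()    # hypothetical builder; assumes torch.distributed is initialized

The test added below exercises the same path: it initializes one model replicated and one with a shard plan, reshards both to the same spec, and checks that the parameters match.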
tests/test_tensor/test_context.py
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.tensor import (
    ColoParameter,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def run_colo_init_context(rank: int, world_size: int, port: int):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated.
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # keep parameters replicated during init
    with ColoInitContext(device=get_current_device()):
        model1 = model_builder()

    # shard the parameters during init
    set_seed(42)
    shard_spec = ReplicaSpec()
    # ShardSpec(dims=[0], num_partitions=[world_size])
    default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
    with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
        model2 = model_builder()

    # reshard both models
    new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        p1: ColoParameter = p1
        p1.set_process_group(ProcessGroup(tp_degree=world_size))
        p1.set_dist_spec(new_shard)
        p2.set_dist_spec(new_shard)

    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert (torch.allclose(p1, p2))


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_colo_init_context(world_size):
    run_func = partial(run_colo_init_context, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_colo_init_context(2)
tests/test_tensor/test_sharded_linear.py
 from functools import partial
-from lib2to3 import pgen2

 import pytest
 import torch
 ...
tests/test_tensor/test_tp_with_zero.py

 ...
@@ -18,7 +18,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
+from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
 ...