Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
3b500984
Unverified
Commit
3b500984
authored
Jul 08, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 08, 2022
Browse files
[tensor] fix some unittests (#1234)
parent
a45ddf2d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
27 additions
and
11 deletions
+27
-11
colossalai/nn/_ops/linear.py
colossalai/nn/_ops/linear.py
+3
-2
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+6
-3
colossalai/utils/model/colo_init_context.py
colossalai/utils/model/colo_init_context.py
+5
-2
tests/test_ddp/test_ddp_state_dict.py
tests/test_ddp/test_ddp_state_dict.py
+9
-1
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+2
-3
tests/test_utils/test_activation_checkpointing.py
tests/test_utils/test_activation_checkpointing.py
+1
-0
tests/test_utils/test_colo_checkpoint.py
tests/test_utils/test_colo_checkpoint.py
+1
-0
No files found.
colossalai/nn/_ops/linear.py
View file @
3b500984
...
...
@@ -11,18 +11,19 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
pg
=
weight
.
get_process_group
()
input_tensor
=
input_tensor
.
convert_to_dist_spec
(
distspec
.
shard
([
-
1
],
[
weight
.
get_tp_world_size
()]))
# Output:P
partial_output
=
F
.
linear
(
input_tensor
,
weight
)
# Reduce(Output)
output
=
reduce_input
(
partial_output
,
weight
.
get_process_group
())
output
=
reduce_input
(
partial_output
,
pg
)
# Bias
if
bias
is
not
None
:
assert
not
bias
.
has_compute_spec
(),
'Invalid bias spec for 1Drow Linear op'
output
=
output
+
bias
pg
=
weight
.
get_process_group
()
output
=
ColoTensor
.
from_torch_tensor
(
output
,
spec
=
ColoTensorSpec
(
pg
,
distspec
.
replicate
()))
return
output
...
...
colossalai/tensor/colo_tensor.py
View file @
3b500984
...
...
@@ -72,7 +72,7 @@ class ColoTensor(torch.Tensor):
def
__init__
(
self
,
data
:
torch
.
Tensor
,
spec
:
Optional
[
ColoTensorSpec
]
=
None
)
->
None
:
# If not set spec, use a DP process group and replicate dist spec
if
not
spec
:
if
spec
is
None
:
self
.
has_initialized
=
False
self
.
dist_spec
=
distspec
.
replicate
()
self
.
compute_spec
=
None
...
...
@@ -81,7 +81,10 @@ class ColoTensor(torch.Tensor):
self
.
has_initialized
=
True
self
.
dist_spec
=
spec
.
dist_attr
self
.
compute_spec
=
spec
.
compute_attr
self
.
process_group
=
spec
.
pg
if
spec
.
pg
is
None
:
self
.
process_group
=
ProcessGroup
()
else
:
self
.
process_group
=
spec
.
pg
self
.
_type
=
TensorType
.
NONMODEL
self
.
_graph_node
=
None
...
...
@@ -125,7 +128,7 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): target dist spec.
"""
assert
isinstance
(
dist_spec
,
_DistSpec
)
assert
self
.
process_group
assert
self
.
process_group
is
not
None
self
.
_convert_to_dist_spec
(
dist_spec
)
def
set_tensor_spec
(
self
,
dist_spec
,
compute_spec
):
...
...
colossalai/utils/model/colo_init_context.py
View file @
3b500984
from
.utils
import
InsertPostInitMethodToModuleSubClasses
import
torch
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
,
distspec
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
,
distspec
,
ProcessGroup
from
colossalai.nn.parallel.layers
import
register_colo_module
,
\
ColoLinear
,
ColoEmbedding
...
...
@@ -47,8 +47,11 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
has_dist_parameter
=
True
mapping1
[
id
(
param
)]
=
copy
(
param
.
dist_spec
)
mapping2
[
id
(
param
)]
=
copy
(
param
.
compute_spec
)
mapping3
[
id
(
param
)]
=
param
.
get_process_group
()
# TODO(jiaruifang) fixme, we should elegently handle the default PG in init context
if
param
.
get_process_group
()
is
None
:
param
.
process_group
=
ProcessGroup
()
param
.
set_dist_spec
(
distspec
.
replicate
())
mapping3
[
id
(
param
)]
=
param
.
get_process_group
()
param
.
process_group
=
None
# TODO: fix when keep_vars = True
...
...
tests/test_ddp/test_ddp_state_dict.py
View file @
3b500984
...
...
@@ -13,7 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Callable
from
collections
import
OrderedDict
from
colossalai.tensor
import
ProcessGroup
from
colossalai.tensor
import
ProcessGroup
,
ColoParameter
def
check_state_dict_equal
(
state_dict
:
OrderedDict
,
other_state_dict
:
OrderedDict
):
...
...
@@ -43,7 +43,15 @@ def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
model
=
model_builder
()
model
=
ddp_init_func
(
model
)
torch_state_dict
=
torch_model
.
state_dict
()
for
param
in
model
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
assert
param
.
get_process_group
()
is
not
None
model
.
load_state_dict
(
torch_state_dict
)
for
param
in
model
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
assert
param
.
get_process_group
()
is
not
None
state_dict
=
model
.
state_dict
()
check_state_dict_equal
(
torch_state_dict
,
state_dict
)
...
...
tests/test_tensor/test_model.py
View file @
3b500984
...
...
@@ -186,7 +186,6 @@ def test_model_parameters():
assert
param_cnt
==
2
# @pytest.mark.skip
def
test_colo_optimizer
():
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'simple_net'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
...
@@ -316,7 +315,7 @@ def run_model_dist(rank, world_size, port):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
for
name
in
[
'simple_net'
]:
run_1d_row_tp
(
name
)
for
name
in
[
'bert'
,
'simple_net'
]:
for
name
in
[
'simple_net'
]:
run_1d_hybrid_tp
(
name
)
...
...
@@ -346,6 +345,6 @@ def test_pretrain_load(world_size):
if
__name__
==
'__main__'
:
# test_model_parameters()
# test_colo_optimizer()
# test_colo_opt
g
imizer()
test_model
(
4
)
# test_pretrain_load(4)
tests/test_utils/test_activation_checkpointing.py
View file @
3b500984
...
...
@@ -17,6 +17,7 @@ def forward(x, weight):
@
pytest
.
mark
.
gpu
@
pytest
.
mark
.
skip
(
"set seed error"
)
@
pytest
.
mark
.
parametrize
(
"cpu_offload"
,
[
True
,
False
])
def
test_activation_checkpointing
(
cpu_offload
):
...
...
tests/test_utils/test_colo_checkpoint.py
View file @
3b500984
...
...
@@ -215,6 +215,7 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
run_checkpoint
(
init_1d_row_for_linear_weight_spec
,
use_ddp
,
test_epoch
,
test_scheduler
,
pg
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
True
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment