OpenDAS / ColossalAI · Commits

Commit 797a9dc5 (unverified)
Authored May 13, 2022 by Ziyue Jiang; committed by GitHub on May 13, 2022
Parent: 67c33f57

add DistSpec for loss and test_model (#947)

Showing 6 changed files with 64 additions and 94 deletions.
colossalai/tensor/_ops/__init__.py   +1 / -1
colossalai/tensor/_ops/layernorm.py  +1 / -1
colossalai/tensor/_ops/loss.py       +4 / -5
colossalai/tensor/dist_spec_mgr.py   +4 / -2
colossalai/tensor/spec.py            +14 / -1
tests/test_tensor/test_model.py      +40 / -84
colossalai/tensor/_ops/__init__.py

 from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
-# from .loss import colo_cross_entropy
+from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
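The un-commented import matters because importing the module is what registers the handler: the `@colo_op_impl` decorator runs at import time. A minimal sketch of that registration pattern, assuming the decorator takes the intercepted torch function (only the handler signature is confirmed, by the loss.py hunk header below; the decorator argument is an assumption):

    import torch
    from colossalai.tensor.op_wrapper import colo_op_impl

    # Assumed registration pattern; the signature matches the loss.py hunk
    # header in this commit, everything else is illustrative.
    @colo_op_impl(torch.nn.functional.cross_entropy)
    def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
        ...  # dispatch to the parallel-aware loss, as in loss.py below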
colossalai/tensor/_ops/layernorm.py

@@ -28,7 +28,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
     if isinstance(input_tensor, ColoTensor):
         # TODO (ver217): check input dist spec
-        input_tensor.to_dist_spec(dist_spec.replicate())
+        input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
         weight = weight.torch_tensor()
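The one-line fix threads the input tensor's process group into the replicate spec, so the later dist-spec conversion sees a concrete group instead of none. A rough sketch of what the `dist_spec.replicate` factory presumably looks like (the constructor shape and the default are assumptions; the real code lives in colossalai/tensor/dist_spec.py):

    from colossalai.tensor.dist_spec import DistPlacementPattern, _DistSpec

    def replicate(process_group=None) -> _DistSpec:
        # Called with no argument, the spec carries process_group=None --
        # exactly the case the DistSpecManager change below tolerates.
        return _DistSpec(DistPlacementPattern.REPLICATE, process_group)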
colossalai/tensor/_ops/loss.py

-from colossalai.tensor.spec import ShardPattern
+from colossalai.tensor.dist_spec import DistPlacementPattern
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor

@@ -27,12 +27,11 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
     if isinstance(target, ColoTensor):
         target = target.torch_tensor()
-    if input_tensor.is_gathered():    # Input is gathered
-        # TODO(jzy) Shall we make the result of loss function a ColoTensor?
+    if input_tensor.spec.is_gathered():    # Input is gathered
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
             input_tensor.torch_tensor(), target, weight))
-    elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1:    # Single Model Parallel Applied
-        if input_tensor.shard_pattern == ShardPattern.Col:
+    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
+        if input_tensor.spec.is_1Dcol():
             return ColoTensor.init_from_torch_tensor(
                 VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
         else:
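Net effect: the gather/shard queries move off ColoTensor (`is_gathered()`, `shard_pattern`) onto its TensorSpec, and the `ShardPattern.Col` comparison becomes the `is_1Dcol()` predicate added to spec.py below. Condensed as a sketch (a paraphrase of the diff, not new behavior; `input_tensor` is a ColoTensor and `parallel_loss_fn` stands in for `VocabParallelCrossEntropyLoss1D()`):

    import torch

    def dispatch_cross_entropy(input_tensor, target, weight, parallel_loss_fn):
        if input_tensor.spec.is_gathered():    # replicated input: plain torch loss
            return torch.nn.functional.cross_entropy(
                input_tensor.torch_tensor(), target, weight)
        if input_tensor.has_spec() and input_tensor.spec.num_action == 1 \
                and input_tensor.spec.is_1Dcol():    # vocab-sharded input
            return parallel_loss_fn(input_tensor.torch_tensor(), target)
        raise NotImplementedError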
colossalai/tensor/dist_spec_mgr.py

@@ -53,7 +53,8 @@ class DistSpecManager:
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return tensor

@@ -65,7 +66,8 @@ class DistSpecManager:
     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return DistSpecManager._gather(tensor, old_dist_spec)
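Both guards now accept a target spec whose `process_group` is `None`, treating it as "stay on the tensor's current group"; only a genuine cross-group migration still raises. The `_r2r` condition, restated standalone (`groups_compatible` is a hypothetical helper for illustration, not library code):

    def groups_compatible(old_pg, new_pg) -> bool:
        # NotImplementedError fires only when both groups are set and differ.
        return old_pg is None or new_pg is None or old_pg == new_pg

    assert groups_compatible("pg0", None)        # e.g. a bare dist_spec.replicate() target
    assert groups_compatible("pg0", "pg0")       # same group: fine
    assert not groups_compatible("pg0", "pg1")   # real group-to-group move: unsupported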
colossalai/tensor/spec.py

 from enum import Enum
 from typing import List
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor.dist_spec import _DistSpec
+from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):

@@ -84,3 +84,16 @@ class TensorSpec(object):
     def get_process_group(self):
         return self.dist_spec.process_group
+
+    def get_placement(self):
+        return self.dist_spec.placement
+
+    def is_gathered(self):
+        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
+            or (len(self.dist_spec.num_partitions) == 1 and self.dist_spec.num_partitions[0] == 1) \
+            or (self.dist_spec.process_group.size() == 1)
+
+    def is_1Dcol(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD \
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
\ No newline at end of file
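A standalone model of the two new predicates, to make the semantics concrete; the dataclass and enum values are stand-ins, only the boolean logic is taken from the diff:

    from dataclasses import dataclass, field
    from enum import Enum
    from typing import List

    class DistPlacementPattern(Enum):    # stand-in for colossalai's enum
        REPLICATE = 'r'
        SHARD = 's'

    @dataclass
    class FakeSpec:                      # hypothetical stand-in for _DistSpec
        placement: DistPlacementPattern
        dims: List[int] = field(default_factory=list)
        num_partitions: List[int] = field(default_factory=list)
        world_size: int = 1              # models process_group.size()

    def is_gathered(s: FakeSpec) -> bool:
        return s.placement == DistPlacementPattern.REPLICATE \
            or (len(s.num_partitions) == 1 and s.num_partitions[0] == 1) \
            or s.world_size == 1

    def is_1Dcol(s: FakeSpec) -> bool:
        return s.placement == DistPlacementPattern.SHARD \
            and len(s.dims) == 1 and s.dims[0] == -1

    # A shard over a single partition still counts as gathered:
    assert is_gathered(FakeSpec(DistPlacementPattern.SHARD, [0], [1], world_size=4))
    # Sharding the last dim on a 1-D mesh is "1D column" sharding:
    assert is_1Dcol(FakeSpec(DistPlacementPattern.SHARD, [-1], [4], world_size=4))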
tests/test_tensor/test_model.py

@@ -9,7 +9,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer
+from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, \
+    ParallelAction, ColoTensor, ColoOptimizer, dist_spec, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc

@@ -85,6 +86,34 @@ def set_seed(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True

+
+def init_1d_row_linear(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_col_linear(weight, gather_out=True):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
+            gather_out=gather_out)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_row_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_col_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear

@@ -106,84 +135,35 @@ def run_1d_hybrid_tp(model_name):
             p2.data.copy_(p1.data)

     if 'bert' == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_linear_row = TensorSpec(parallel_action_list_row)
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
-        parallel_action_list_embedding_row = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             # print(name)
            # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                p.set_spec(spec_linear_row)
+                init_1d_row_linear(p)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = seq_len = 512 | (512, 8)
             if 'position_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
     elif "simple_net" == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_row = TensorSpec(parallel_action_list_row)
-        parallel_action_list_col = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D),
-        ]
-        spec_col = TensorSpec(parallel_action_list_col)
-        parallel_action_list_classifier_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D,
-                           gather_out=False),
-        ]
-        spec_classifier_col = TensorSpec(parallel_action_list_classifier_col)
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
         # A naive way to set spec for all weights in Linear
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
+                init_1d_col_linear(p)
             if 'proj2' in name and 'weight' in name:
-                p.set_spec(spec_row)
+                init_1d_row_linear(p)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_classifier_col)
+                init_1d_col_linear(p, gather_out=False)

     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)

@@ -251,8 +231,6 @@ def run_1d_hybrid_tp(model_name):
         break

-# FIXME (ver217): enable this test
-@pytest.mark.skip
 # Test the overrided parameters() and named_parameters() member functions
 def test_model_parameters():
     # build a module with 2 Linear, 4 parameters in total.

@@ -285,8 +263,6 @@ def test_model_parameters():
     assert param_cnt == 2

-# FIXME (ver217): enable this test
-@pytest.mark.skip
 def test_colo_optimizer():
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -329,29 +305,14 @@ def run_1d_row_tp(model_name: str):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-    parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec = TensorSpec(parallel_action_list)
-    parallel_action_list_embedding_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
     # A naive way to set spec for all weights in Linear
     for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            p.set_spec(spec)
+            init_1d_row_linear(p)
         if 'embed' in name and 'weight' in name:
-            p.set_spec(spec_embedding_row)
+            init_1d_row_embedding(p)

     model = model.cuda()

@@ -434,9 +395,6 @@ def run_model_dist(rank, world_size, port):
     for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)

-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 # @parameterize('world_size', [1, 4])

@@ -454,8 +412,6 @@ def run_pretrain_load_dist(rank, world_size, port):
 # The test case has to download huggingface pretrained models from the internet
 # So we manually trigger the test.
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
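The test refactor replaces the per-site `TensorSpec([ParallelAction(...)])` plus `p.set_spec(...)` pattern with the four `init_1d_*` helpers, which also attach an explicit `dist_spec.shard(...)` placement and wrap the assignment in `DistSpecManager.no_grad()`. Before/after, as a sketch mirroring `init_1d_row_linear` from the diff above (`shard_row` is my name for it; running this requires an initialized colossalai 1-D parallel context):

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import (ComputePattern, DistSpecManager, ParallelAction,
                                   TensorSpec, dist_spec)

    # Before: compute pattern only.
    #   spec = TensorSpec([ParallelAction(priority=1,
    #                                     compute_pattern=ComputePattern.TP1DRow_Linear,
    #                                     parallel_mode=ParallelMode.PARALLEL_1D)])
    #   p.set_spec(spec)

    # After: placement (how the data is sharded) plus compute pattern, with the
    # spec assignment guarded against autograd tracking.
    def shard_row(p):
        spec = TensorSpec(
            dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                            [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
            [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow,
                            parallel_mode=ParallelMode.PARALLEL_1D)])
        with DistSpecManager.no_grad():
            p.set_spec(spec)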