Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
260a5580
Unverified
Commit 260a5580 authored Jul 13, 2022 by HELSON
Committed by GitHub Jul 13, 2022
Browse files
[hotfix] fix shape error in backward when using ColoTensor (#1298)
parent
f83c4d65
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing 4 changed files with 26 additions and 56 deletions
+26
-56
colossalai/nn/_ops/addmm.py
colossalai/nn/_ops/addmm.py
+3
-3
colossalai/nn/_ops/linear.py
colossalai/nn/_ops/linear.py
+5
-5
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+7
-5
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+11
-43
No files found.
colossalai/nn/_ops/addmm.py
View file @
260a5580
...
...
@@ -11,16 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1
=
mat1
.
redistribute
(
ShardSpec
([
-
1
],
[
mat2
.
get_tp_world_size
()]))
mat1
=
mat1
.
redistribute
(
ShardSpec
([
-
1
],
[
mat2
.
get_tp_world_size
()])
,
mat2
.
get_process_group
()
)
# Output:P
partial_output
=
torch
.
mm
(
mat1
,
mat2
)
# Reduce(Output)
output
=
reduce_input
(
partial_output
,
mat
1
.
get_process_group
())
output
=
reduce_input
(
partial_output
,
mat
2
.
get_process_group
())
# input
assert
not
input_tensor
.
has_compute_spec
(),
'Invalid input spec for 1Drow addmm op'
output
=
beta
*
input_tensor
+
alpha
*
output
output
=
ColoTensor
.
from_torch_tensor
(
output
,
spec
=
ColoTensorSpec
(
ReplicaSpec
()))
output
=
ColoTensor
.
from_torch_tensor
(
output
,
spec
=
ColoTensorSpec
(
input_tensor
.
get_process_group
()))
return
output
...
...
colossalai/nn/_ops/linear.py
View file @
260a5580
...
...
@@ -3,15 +3,15 @@ from typing import Optional
from
._utils
import
GeneralTensor
,
convert_to_colo_tensor
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
._utils
import
reduce_input
,
reduce_grad
from
colossalai.tensor
import
ComputePattern
,
ComputePattern
,
ComputeSpec
,
ColoTensor
,
ShardSpec
,
ReplicaSpec
,
ColoTensorSpec
from
colossalai.tensor
import
ComputePattern
,
ComputeSpec
,
ColoTensor
,
ShardSpec
,
ReplicaSpec
,
ColoTensorSpec
def
colo_linear_1
D
row
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
def
colo_linear_1
d
row
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
pg
=
weight
.
get_process_group
()
input_tensor
=
input_tensor
.
redistribute
(
ShardSpec
([
-
1
],
[
weight
.
get_tp_world_size
()]))
input_tensor
=
input_tensor
.
redistribute
(
ShardSpec
([
-
1
],
[
weight
.
get_tp_world_size
()])
,
pg
)
# Output:P
partial_output
=
F
.
linear
(
input_tensor
,
weight
)
...
...
@@ -27,7 +27,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
return
output
def
colo_linear_1
D
col
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
def
colo_linear_1
d
col
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
...
...
@@ -48,7 +48,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
def
colo_linear_1d
(
mode
:
str
,
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
assert
mode
in
(
'row'
,
'col'
)
funcs
=
{
'row'
:
colo_linear_1
D
row
,
'col'
:
colo_linear_1
D
col
}
funcs
=
{
'row'
:
colo_linear_1
d
row
,
'col'
:
colo_linear_1
d
col
}
return
funcs
[
mode
](
input_tensor
,
weight
,
bias
)
...
...
colossalai/tensor/colo_tensor.py
View file @
260a5580
...
...
@@ -204,12 +204,14 @@ class ColoTensor(torch.Tensor):
ColoTensor: a redistributed colotensor
"""
if
pg
is
not
None
and
pg
!=
self
.
get_process_group
():
print
(
'here _redistribute'
)
# if the pg is not equal, convert the current tensor to replicated
self
.
_redistribute
(
ReplicaSpec
())
self
.
process_group
=
pg
ret
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
dist_spec
,
dist_spec
,
self
.
process_group
)
return
ColoTensor
.
from_torch_tensor
(
ret
,
ColoTensorSpec
(
self
.
process_group
,
dist_attr
=
dist_spec
))
handled
=
self
.
redistribute
(
ReplicaSpec
())
else
:
handled
=
self
pg
=
self
.
process_group
ret
=
DistSpecManager
.
handle_trans_spec
(
handled
,
handled
.
dist_spec
,
dist_spec
,
pg
)
return
ColoTensor
.
from_torch_tensor
(
ret
,
ColoTensorSpec
(
pg
=
pg
,
dist_attr
=
dist_spec
))
def
to_replicate_
(
self
):
"""to_replicate_
...
...
tests/test_tensor/test_model.py
View file @
260a5580
...
...
@@ -11,42 +11,13 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ShardSpec
,
ColoTensorSpec
,
ComputePattern
,
\
ComputeSpec
,
ColoTensor
,
DistSpecManager
,
ProcessGroup
,
ReplicaSpec
from
colossalai.tensor
import
ColoTensor
,
ProcessGroup
from
colossalai.nn.optimizer
import
ColoOptimizer
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
_utils
import
split_param_row_tp1d
,
split_param_col_tp1d
def
init_1d_row_linear
(
weight
:
ColoTensor
,
pg
:
ProcessGroup
):
spec
=
(
ShardSpec
([
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
weight
.
set_process_group
(
pg
)
weight
.
set_tensor_spec
(
*
spec
)
def
init_1d_col_linear
(
weight
,
pg
):
spec
=
(
ShardSpec
([
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
weight
.
set_process_group
(
pg
)
weight
.
set_tensor_spec
(
*
spec
)
def
init_1d_row_embedding
(
weight
,
pg
):
spec
=
(
ShardSpec
([
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
weight
.
set_process_group
(
pg
)
weight
.
set_tensor_spec
(
*
spec
)
def
init_1d_col_embedding
(
weight
,
pg
):
spec
=
(
ShardSpec
([
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
weight
.
set_process_group
(
pg
)
weight
.
set_tensor_spec
(
*
spec
)
def
run_1d_hybrid_tp
(
model_name
):
# A simple net with two stacked nn.Linear
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
...
...
@@ -79,19 +50,16 @@ def run_1d_hybrid_tp(model_name):
# num_class = type_vocab_size = 2 | (8, 2)
if
'classifier'
in
name
and
'weight'
in
name
:
init_1d_row_linear
(
p
,
pg
)
split_param_col_tp1d
(
p
,
pg
)
# num_class = vocab_size = 30524 | (30524, 8)
elif
'word_embeddings'
in
name
and
'weight'
in
name
:
init_1d_row_embedding
(
p
,
pg
)
split_param_row_tp1d
(
p
,
pg
)
# num_class = seq_len = 512 | (512, 8)
elif
'position_embeddings'
in
name
and
'weight'
in
name
:
init_1d_row_embedding
(
p
,
pg
)
split_param_row_tp1d
(
p
,
pg
)
# num_class = type_vocab_size = 2 | (2, 8)
elif
'token_type_embeddings'
in
name
and
'weight'
in
name
:
init_1d_col_embedding
(
p
,
pg
)
elif
p
.
process_group
.
tp_world_size
()
==
1
:
with
DistSpecManager
.
no_grad
():
p
.
redistribute
(
ReplicaSpec
(),
pg
)
split_param_col_tp1d
(
p
,
pg
)
elif
"simple_net"
==
model_name
:
# A naive way to set spec for all weights in Linear
...
...
@@ -99,13 +67,13 @@ def run_1d_hybrid_tp(model_name):
if
not
isinstance
(
p
,
ColoTensor
):
continue
if
'embed'
in
name
and
'weight'
in
name
:
init_1d_col_embedding
(
p
,
pg
)
split_param_col_tp1d
(
p
,
pg
)
if
'proj1'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
init_1d_col_linear
(
p
,
pg
)
split_param_row_tp1d
(
p
,
pg
)
if
'proj2'
in
name
and
'weight'
in
name
:
init_1d_row_linear
(
p
,
pg
)
split_param_col_tp1d
(
p
,
pg
)
if
'classifier'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
init_1d_col_linear
(
p
,
pg
)
split_param_row_tp1d
(
p
,
pg
)
model
=
model
.
cuda
()
model
.
train
()
...
...
@@ -327,9 +295,9 @@ def _run_pretrain_load():
def
run_model_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
for
name
in
[
'bert'
]:
for
name
in
[
'bert'
,
'simple_net'
]:
run_1d_row_tp
(
name
)
for
name
in
[
'bert'
]:
for
name
in
[
'bert'
,
'simple_net'
]:
run_1d_hybrid_tp
(
name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment