Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
177c3744
Unverified
Commit
177c3744
authored
Jun 23, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 23, 2022
Browse files
remove gather out in parallel action (#1163)
parent
51f1ec96
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
43 additions
and
32 deletions
+43
-32
colossalai/nn/_ops/addmm.py
colossalai/nn/_ops/addmm.py
+4
-13
colossalai/nn/_ops/embedding.py
colossalai/nn/_ops/embedding.py
+1
-3
colossalai/nn/_ops/embedding_bag.py
colossalai/nn/_ops/embedding_bag.py
+2
-3
colossalai/nn/_ops/linear.py
colossalai/nn/_ops/linear.py
+1
-4
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+29
-3
colossalai/tensor/spec.py
colossalai/tensor/spec.py
+2
-3
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+1
-0
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+3
-3
No files found.
colossalai/nn/_ops/addmm.py
View file @
177c3744
...
@@ -37,10 +37,10 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
...
@@ -37,10 +37,10 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
output_spec
=
TensorSpec
(
distspec
.
shard
(
mat2
.
spec
.
get_process_group
(),
[
-
1
],
[
mat2
.
spec
.
get_process_group_size
()]),
output_spec
=
TensorSpec
(
distspec
.
shard
(
mat2
.
spec
.
get_process_group
(),
[
-
1
],
[
mat2
.
spec
.
get_process_group_size
()]),
ParallelAction
(
ComputePattern
.
TP1D
))
ParallelAction
(
ComputePattern
.
TP1D
))
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
if
parallel_action
.
gather_out
:
# All-Gather(Output)
# TODO(jiaruifang) addmm is a special case
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
mat2
.
spec
.
get_process_group
()))
# since GPT calls view after the op.
return
output
return
output
.
to_replicate
()
def
colo_addmm_1d
(
mode
:
str
,
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
def
colo_addmm_1d
(
mode
:
str
,
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
...
@@ -62,11 +62,6 @@ def colo_addmm(input_tensor: GeneralTensor,
...
@@ -62,11 +62,6 @@ def colo_addmm(input_tensor: GeneralTensor,
"""
"""
input_tensor
,
mat1
,
mat2
=
tuple
(
map
(
convert_to_colo_tensor
,
(
input_tensor
,
mat1
,
mat2
)))
input_tensor
,
mat1
,
mat2
=
tuple
(
map
(
convert_to_colo_tensor
,
(
input_tensor
,
mat1
,
mat2
)))
# building the computing graph, inputs -> op
# if GraphGlobalEnv().graph_building:
# cur_op_node = GraphOpNode('linear', [weight, bias])
# cur_op_node.add_prev_tensor(input_tensor)
# Add communication logic before and after linear call.
# Add communication logic before and after linear call.
ret_tensor
=
None
ret_tensor
=
None
if
not
mat2
.
has_spec
():
# No Model Parallel Applied
if
not
mat2
.
has_spec
():
# No Model Parallel Applied
...
@@ -84,8 +79,4 @@ def colo_addmm(input_tensor: GeneralTensor,
...
@@ -84,8 +79,4 @@ def colo_addmm(input_tensor: GeneralTensor,
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
# building the computing graph, op -> output
# if GraphGlobalEnv().graph_building:
# cur_op_node.add_post_tensor(ret_tensor)
return
ret_tensor
return
ret_tensor
colossalai/nn/_ops/embedding.py
View file @
177c3744
...
@@ -30,9 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
...
@@ -30,9 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
ParallelAction
(
ComputePattern
.
TP1D
))
ParallelAction
(
ComputePattern
.
TP1D
))
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
if
weight
.
spec
.
parallel_action
.
gather_out
:
return
output
.
to_replicate
()
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
return
output
def
colo_embedding_1Drow
(
input_tensor
:
ColoTensor
,
def
colo_embedding_1Drow
(
input_tensor
:
ColoTensor
,
...
...
colossalai/nn/_ops/embedding_bag.py
View file @
177c3744
...
@@ -36,9 +36,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
...
@@ -36,9 +36,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
ParallelAction
(
ComputePattern
.
TP1D
))
ParallelAction
(
ComputePattern
.
TP1D
))
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
if
weight
.
spec
.
parallel_action
.
gather_out
:
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
return
output
.
to_replicate
()
return
output
def
colo_embedding_bag_1d
(
tp_mode
:
str
,
def
colo_embedding_bag_1d
(
tp_mode
:
str
,
...
...
colossalai/nn/_ops/linear.py
View file @
177c3744
...
@@ -42,10 +42,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
...
@@ -42,10 +42,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
[
weight
.
spec
.
get_process_group_size
()]),
ParallelAction
(
ComputePattern
.
TP1D
)))
ParallelAction
(
ComputePattern
.
TP1D
)))
if
parallel_action
.
gather_out
:
return
output
.
to_replicate
()
# All-Gather(Output)
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
return
output
def
colo_linear_1d
(
mode
:
str
,
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
def
colo_linear_1d
(
mode
:
str
,
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
'ColoTensor'
:
...
...
colossalai/tensor/colo_tensor.py
View file @
177c3744
...
@@ -92,10 +92,13 @@ class ColoTensor(torch.Tensor):
...
@@ -92,10 +92,13 @@ class ColoTensor(torch.Tensor):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'ColoTensor:
{
super
().
__repr__
()
}
'
return
f
'ColoTensor:
{
super
().
__repr__
()
}
'
def
is_model_data
(
self
)
->
bool
:
return
self
.
_type
==
TensorType
.
MODEL
def
_convert_to_dist_spec
(
self
,
dist_spec
:
_DistSpec
)
->
None
:
def
_convert_to_dist_spec
(
self
,
dist_spec
:
_DistSpec
)
->
None
:
"""_convert_to_dist_spec
Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
Args:
dist_spec (_DistSpec): the target dist. spec.
"""
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
self
.
data
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
spec
.
dist_spec
,
dist_spec
)
self
.
data
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
spec
.
dist_spec
,
dist_spec
)
self
.
_tensor_spec
.
dist_spec
=
dist_spec
self
.
_tensor_spec
.
dist_spec
=
dist_spec
...
@@ -106,6 +109,19 @@ class ColoTensor(torch.Tensor):
...
@@ -106,6 +109,19 @@ class ColoTensor(torch.Tensor):
ret
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
spec
.
dist_spec
,
dist_spec
)
ret
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
spec
.
dist_spec
,
dist_spec
)
return
ColoTensor
.
from_torch_tensor
(
ret
,
tensor_spec
)
return
ColoTensor
.
from_torch_tensor
(
ret
,
tensor_spec
)
def
to_replicate_
(
self
):
"""to_replicate_
an in-place member function that converts the dist spec of the tensor to REPLICATE
"""
self
.
data
=
DistSpecManager
.
handle_trans_spec
(
self
,
self
.
spec
.
dist_spec
,
distspec
.
replicate
())
self
.
_tensor_spec
.
dist_spec
=
distspec
.
replicate
()
def
to_replicate
(
self
)
->
'ColoTensor'
:
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
return
self
.
convert_to_dist_spec
(
distspec
.
replicate
(
self
.
spec
.
get_process_group
()))
@
staticmethod
@
staticmethod
def
from_torch_tensor
(
tensor
:
torch
.
Tensor
,
spec
:
TensorSpec
=
TensorSpec
(
distspec
.
replicate
()))
->
'ColoTensor'
:
def
from_torch_tensor
(
tensor
:
torch
.
Tensor
,
spec
:
TensorSpec
=
TensorSpec
(
distspec
.
replicate
()))
->
'ColoTensor'
:
tensor
=
tensor
.
as_subclass
(
ColoTensor
)
tensor
=
tensor
.
as_subclass
(
ColoTensor
)
...
@@ -121,3 +137,13 @@ class ColoTensor(torch.Tensor):
...
@@ -121,3 +137,13 @@ class ColoTensor(torch.Tensor):
tensor
=
ColoTensor
(
data
,
spec
=
copy
(
self
.
spec
))
tensor
=
ColoTensor
(
data
,
spec
=
copy
(
self
.
spec
))
memo
[
id
(
self
)]
=
tensor
memo
[
id
(
self
)]
=
tensor
return
tensor
return
tensor
# TODO(jiaruifang) a patch for gpt test.
# We need to override member functions that must operate on a replicated tensor
# def view(self, *args, **kwargs):
# self.data = DistSpecManager.handle_trans_spec(self,
# self.spec.dist_spec,
# distspec.replicate(self.spec.get_process_group()))
# # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
# self.data.view(*args, **kwargs)
# return ColoTensor.from_torch_tensor(self.data)
colossalai/tensor/spec.py
View file @
177c3744
...
@@ -13,13 +13,12 @@ class ComputePattern(Enum):
...
@@ -13,13 +13,12 @@ class ComputePattern(Enum):
class
ParallelAction
(
object
):
class
ParallelAction
(
object
):
def
__init__
(
self
,
compute_pattern
:
ComputePattern
,
gather_out
:
bool
=
True
)
->
None
:
def
__init__
(
self
,
compute_pattern
:
ComputePattern
)
->
None
:
assert
isinstance
(
compute_pattern
,
ComputePattern
)
assert
isinstance
(
compute_pattern
,
ComputePattern
)
self
.
compute_pattern
=
compute_pattern
self
.
compute_pattern
=
compute_pattern
self
.
gather_out
=
gather_out
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'compute pattern:
{
self
.
compute_pattern
}
, gather out:
{
self
.
gather_out
}
'
return
f
'compute pattern:
{
self
.
compute_pattern
}
'
class
TensorSpec
(
object
):
class
TensorSpec
(
object
):
...
...
tests/test_tensor/test_linear_tp.py
View file @
177c3744
...
@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
...
@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
x
=
torch
.
rand
(
2
,
4
).
cuda
()
x
=
torch
.
rand
(
2
,
4
).
cuda
()
out
=
model
(
x
)
out
=
model
(
x
)
colo_out
=
F
.
linear
(
x
,
weight
,
bias
)
colo_out
=
F
.
linear
(
x
,
weight
,
bias
)
colo_out
=
colo_out
.
to_replicate
()
assert
tensor_equal
(
out
,
colo_out
)
assert
tensor_equal
(
out
,
colo_out
)
grad
=
torch
.
rand_like
(
out
)
grad
=
torch
.
rand_like
(
out
)
out
.
backward
(
grad
)
out
.
backward
(
grad
)
...
...
tests/test_tensor/test_model.py
View file @
177c3744
...
@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
...
@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
weight
.
set_spec
(
spec
)
weight
.
set_spec
(
spec
)
def
init_1d_col_linear
(
weight
,
gather_out
=
True
):
def
init_1d_col_linear
(
weight
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ParallelAction
(
ComputePattern
.
TP1D
,
gather_out
=
gather_out
))
ParallelAction
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
weight
.
set_spec
(
spec
)
...
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
...
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
if
'proj2'
in
name
and
'weight'
in
name
:
if
'proj2'
in
name
and
'weight'
in
name
:
init_1d_row_linear
(
p
)
init_1d_row_linear
(
p
)
if
'classifier'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
if
'classifier'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
init_1d_col_linear
(
p
,
gather_out
=
False
)
init_1d_col_linear
(
p
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
colo_optimizer
=
ColoOptimizer
(
dict
(
model
.
named_parameters
()),
torch
.
optim
.
SGD
,
lr
=
0.1
)
colo_optimizer
=
ColoOptimizer
(
dict
(
model
.
named_parameters
()),
torch
.
optim
.
SGD
,
lr
=
0.1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment