ColossalAI (OpenDAS) commit 845856ea (unverified)

[Graph] building computing graph with ColoTensor, Linear only (#917)

Authored May 07, 2022 by Jiarui Fang; committed by GitHub on May 07, 2022
Parent: 75d22191

Showing 6 changed files with 202 additions and 9 deletions (+202 −9)
colossalai/tensor/_ops/linear.py       +18 −7
colossalai/tensor/colo_tensor.py       +1 −0
colossalai/tensor/graph/__init__.py    +3 −0   (new file)
colossalai/tensor/graph/graph_node.py  +97 −0  (new file)
tests/test_tensor/test_graph.py        +81 −0  (new file)
tests/test_tensor/test_model.py        +2 −2
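Taken together, the diff adds a minimal graph-recording layer: while GraphGlobalEnv().graph_building is true, the ColoTensor linear op registers a GraphOpNode before computing and links its output afterwards, and GraphContext toggles the recording flag. A minimal usage sketch, modeled on the new test; build_model is a hypothetical stand-in for any nn.Module whose forward runs through the ColoTensor linear op:

    import torch
    from colossalai.tensor import ColoTensor
    from colossalai.tensor.graph import GraphContext

    model = build_model()    # hypothetical helper; the new test uses a 4-layer MLP
    colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))

    graph_ctx = GraphContext()
    with graph_ctx:                  # __enter__ sets GraphGlobalEnv().graph_building
        output = model(colo_input)   # each linear call records a GraphOpNode

    for node in graph_ctx.graph_nodes:   # nodes captured by __exit__
        node.print()                     # op type, id, prev/post node ids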
colossalai/tensor/_ops/linear.py (+18 −7):

 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.context import ParallelMode
-from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
+    gather_forward_split_backward, reduce_grad
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
+from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
     parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     ...

@@ -99,20 +98,32 @@ def colo_linear(types, args, kwargs, pg):
     if bias is not None and not isinstance(bias, ColoTensor):
         bias = ColoTensor.init_from_torch_tensor(bias)

+    # building the computing graph, inputs -> op
+    if GraphGlobalEnv().graph_building:
+        cur_op_node = GraphOpNode('linear', [weight, bias])
+        cur_op_node.add_prev_tensor(input_tensor)
+
     # Add communication logic before and after linear call.
+    ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
         assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
         input_tensor = input_tensor.torch_tensor()
         weight = weight.torch_tensor()
         bias = bias.torch_tensor()
-        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
         if ComputePattern.TP1DRow_Linear in compute_patterns:
-            return colo_linear_1Drow(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
         elif ComputePattern.TP1DCol_Linear in compute_patterns:
-            return colo_linear_1Dcol(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
     else:
         raise NotImplementedError

+    # building the computing graph, op -> output
+    if GraphGlobalEnv().graph_building:
+        cur_op_node.add_post_tensor(ret_tensor)
+
+    return ret_tensor
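The instrumentation pattern above is symmetric (record inputs -> op before computing, then op -> output after) and does not depend on linear specifically, so other ColoTensor ops could adopt it the same way. A hedged sketch of the same two hooks wrapped around another op; colo_relu is an illustrative name and not part of this commit, only GraphOpNode and GraphGlobalEnv come from it:

    import torch
    from colossalai.tensor import ColoTensor
    from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv

    def colo_relu(input_tensor: ColoTensor) -> ColoTensor:
        # inputs -> op: record the node and link it to input_tensor's producers
        if GraphGlobalEnv().graph_building:
            cur_op_node = GraphOpNode('relu', [])
            cur_op_node.add_prev_tensor(input_tensor)

        ret_tensor = ColoTensor.init_from_torch_tensor(
            torch.nn.functional.relu(input_tensor.torch_tensor()))

        # op -> output: mark this node as the producer of the result
        if GraphGlobalEnv().graph_building:
            cur_op_node.add_post_tensor(ret_tensor)
        return ret_tensor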
colossalai/tensor/colo_tensor.py (+1 −0):

 ...
@@ -38,6 +38,7 @@ class ColoTensor(object):
         self._shard_spec = shard_spec
         self._shard_pattern = ShardPattern.NA
         self._type = TensorType.NONMODEL
+        self._graph_node = None

     def __getitem__(self, key):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
 ...
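The single added field is the glue between tensors and the graph: _graph_node stays None until an op touches the tensor during graph building, at which point GraphOpNode lazily attaches a GraphNode (see graph_node.py below). A small sketch of that handshake, using only classes introduced by this commit:

    import torch
    from colossalai.tensor import ColoTensor
    from colossalai.tensor.graph import GraphContext, GraphOpNode

    t = ColoTensor.init_from_torch_tensor(torch.randn(4))
    assert t._graph_node is None     # nothing is recorded outside graph building

    with GraphContext():
        op = GraphOpNode('linear', [])
        op.add_post_tensor(t)        # lazily creates t._graph_node
        assert op in t._graph_node.prev_nodes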
colossalai/tensor/graph/__init__.py (new file, +3 −0):

from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv

__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
colossalai/tensor/graph/graph_node.py (new file, +97 −0):

from colossalai.tensor import ColoTensor
from colossalai.context.singleton_meta import SingletonMeta


class GraphGlobalEnv(metaclass=SingletonMeta):

    def __init__(self) -> None:
        self.graph_building = False
        self.graph_node_list = []
        self.node_id = -1

    def get_node_id(self):
        self.node_id += 1
        return self.node_id

    def add_graph_node(self, node):
        self.graph_node_list.append(node)


class GraphContext():
    """
    Building the computing graph under the context

    >>> with GraphContext():
    >>>     output = model(colo_input_tensor)
    """
    graph_nodes = []

    def __enter__(self):
        GraphGlobalEnv().graph_building = True
        GraphGlobalEnv().graph_node_list = []

    def __exit__(self, *exc_info):
        GraphGlobalEnv().graph_building = False
        GraphGlobalEnv().node_id = -1
        self.graph_nodes = GraphGlobalEnv().graph_node_list


class GraphNode(object):

    def __init__(self) -> None:
        self.prev_nodes = []
        self.post_nodes = []
        self.id = GraphGlobalEnv().get_node_id()

    def add_prev_node(self, node):
        if GraphGlobalEnv().graph_building:
            self.prev_nodes.append(node)

    def add_post_node(self, node):
        if GraphGlobalEnv().graph_building:
            self.post_nodes.append(node)

    def post_node_empty(self) -> bool:
        return len(self.post_nodes) == 0


class GraphOpNode(GraphNode):

    def __init__(self, op_type, param_list) -> None:
        super().__init__()
        self._op_type = op_type
        self._param_list = param_list
        GraphGlobalEnv().add_graph_node(self)

    def add_prev_tensor(self, colo_tensor: ColoTensor):
        r"""
        Link the current graph op node to previous graph op.
        Op1 <- Activation (colo_tensor) Op2
        Op1 <- Op2
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor)
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
            prev_ops = colo_tensor._graph_node.prev_nodes
            for op_node in prev_ops:
                self.add_prev_node(op_node)
                op_node.add_post_node(self)

    def add_post_tensor(self, colo_tensor: ColoTensor):
        """
        Op <- Activation (colo_tensor)
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor)
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
            colo_tensor._graph_node.add_prev_node(self)

    def print(self):
        print(f'GraphOpNode {self._op_type} {self.id}, '
              f'post nodes {[node.id for node in self.post_nodes]}, '
              f'prev node number {[node.id for node in self.prev_nodes]}')
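Two design points are worth noting: GraphGlobalEnv uses SingletonMeta so op wrappers can reach the recording state without threading a context handle through every call site, and node ids are handed out monotonically and reset to -1 only in GraphContext.__exit__. A minimal sketch of that contract, assuming a fresh process in which no nodes were created beforehand:

    from colossalai.tensor.graph import GraphContext, GraphGlobalEnv, GraphNode

    assert GraphGlobalEnv() is GraphGlobalEnv()    # SingletonMeta: one shared env

    with GraphContext():
        a, b = GraphNode(), GraphNode()
        assert (a.id, b.id) == (0, 1)              # ids assigned in creation order
    # __exit__ resets node_id to -1, so the next GraphContext starts ids at 0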
tests/test_tensor/test_graph.py (new file, +81 −0):

from torch import nn
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
import gc


class SimpleNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x


def _visit_graph(start_node):
    if start_node is None:
        return
    start_node.print()
    post_node_list = start_node.post_nodes
    for node in post_node_list:
        _visit_graph(node)


def _get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occurred: {}'.format(e))


def _count_tensors():
    cnt = 0
    for t in _get_tensors():
        cnt += 1
    return cnt


def count_tensors(use_colossal):
    model = SimpleNet()
    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
            with graph_ctx:
                output = model(colo_input)
            output = model(colo_input)
            ret = _count_tensors()
            _visit_graph(graph_ctx.graph_nodes[0])
            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


def test_check_activation_tensors():
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)
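The test's pass criterion is indirect: it counts live torch tensors via the garbage collector and asserts that a run under graph building leaves the same number alive as a plain run, i.e. the recorded GraphNodes do not pin activations. The counting idiom is generic; a self-contained sketch of the same leak check around any callable follows, where live_tensor_count and check_no_tensor_leak are illustrative names, not part of the commit:

    import gc
    import torch

    def live_tensor_count() -> int:
        cnt = 0
        for obj in gc.get_objects():
            try:                  # some gc-tracked objects raise on inspection
                if torch.is_tensor(obj):
                    cnt += 1
            except Exception:
                pass
        return cnt

    def check_no_tensor_leak(fn) -> None:
        before = live_tensor_count()
        fn()
        gc.collect()              # release anything only held alive by cycles
        after = live_tensor_count()
        assert after <= before, f'{after - before} tensors leaked'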
tests/test_tensor/test_model.py (+2 −2):

 ...
@@ -26,7 +26,7 @@ from dataclasses import fields
 def _post_init_colo(self):
     class_fields = fields(self)
     # Safety and consistency checks
-    if not len(class_fields):
+    if len(class_fields) == 0:
         raise ValueError(f"{self.__class__.__name__} has no fields.")
     if not all(field.default is None for field in class_fields[1:]):
         raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
 ...