OpenDAS / ColossalAI — commit 07f9c781

Unverified commit 07f9c781, authored Jun 22, 2022 by Jiarui Fang, committed via GitHub on Jun 22, 2022.
[graph] improve the graph building. (#1157)

Parent: 22717a85
Showing 7 changed files, with 79 additions and 103 deletions (+79 -103):
  colossalai/nn/_ops/linear.py          +16  -14
  colossalai/nn/graph/__init__.py        +4   -0
  colossalai/nn/graph/graph_node.py      +1   -2
  colossalai/nn/graph/utils.py          +50   -0
  colossalai/tensor/distspec.py          +8   -0
  colossalai/tensor/graph/__init__.py    +0   -3
  tests/test_tensor/test_graph.py        +0  -84
colossalai/nn/_ops/linear.py (modified)

 import torch.nn.functional as F
 from typing import Optional
+from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
...
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
...
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear transformation.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)

     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
...
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
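For orientation, here is a minimal usage sketch of the new split: ``F.linear`` on ColoTensor arguments dispatches to ``colo_linear`` via ``colo_op_impl``, and the ``register_colo_graph`` wrapper around ``colo_linear_imp`` records the op into the graph while a graph context is active. This assumes ``ColoTensor.from_torch_tensor`` and ``GraphContext`` behave as in the test file removed later in this commit; shapes and variable names are illustrative only.

# Minimal usage sketch (assumption: GraphContext enables GraphGlobalEnv().graph_building,
# as in the removed test below; shapes are illustrative).
import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext

x = ColoTensor.from_torch_tensor(torch.randn(2, 4))
w = ColoTensor.from_torch_tensor(torch.randn(8, 4))

with GraphContext():
    # F.linear on ColoTensor arguments dispatches to colo_linear via colo_op_impl;
    # the register_colo_graph wrapper on colo_linear_imp records the parameters and
    # the activation input as graph nodes around the call.
    y = F.linear(x, w)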
colossalai/nn/graph/__init__.py (new file, 0 → 100644)

from .utils import register_colo_graph
from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode

__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
\ No newline at end of file
colossalai/tensor/graph/graph_node.py → colossalai/nn/graph/graph_node.py (renamed)

...
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
         assert isinstance(colo_tensor, ColoTensor)
         if colo_tensor._graph_node is None:
             colo_tensor._graph_node = GraphNode()
         prev_ops = colo_tensor._graph_node.prev_nodes
         for op_node in prev_ops:
             self.add_prev_node(op_node)
...
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
...
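Reading the first hunk above, the method appears to lazily attach a GraphNode to the activation tensor and then link the current op node after every producer already recorded on it. The following condensed sketch is a reading of the visible fragment only, not the complete method body from the repository.

# Condensed reading of the fragment above (illustrative; not the full method).
def add_prev_tensor(self, colo_tensor):
    assert isinstance(colo_tensor, ColoTensor)
    if colo_tensor._graph_node is None:
        # lazily attach a graph node to the activation tensor
        colo_tensor._graph_node = GraphNode()
    # link this op node after every op that produced the activation
    for op_node in colo_tensor._graph_node.prev_nodes:
        self.add_prev_node(op_node)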
colossalai/nn/graph/utils.py (new file, 0 → 100644)

import functools
import torch
from colossalai.tensor import ColoTensor
from typing import Callable, List
from colossalai.nn._ops._utils import convert_to_colo_tensor


def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register an op (layer) to the ColoGraph.
    Records the input args of ColoTensor type in the graph.
    Args:
        func (Callable): a function that implements the op.
    Returns:
        Callable: wrapper function.
    """

    def register_colo_graph_decorator(func):
        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            param_list = []
            input_list = []
            for idx, arg in enumerate(args):
                if isinstance(arg, torch.Tensor) and idx in input_pos:
                    input_list.append(convert_to_colo_tensor(arg))
                if isinstance(arg, torch.Tensor) and idx in param_pos:
                    param_list.append(convert_to_colo_tensor(arg))
            print(f'Op {func}')
            # building the computing graph, inputs -> op
            if GraphGlobalEnv().graph_building:
                cur_op_node = GraphOpNode('linear', param_list)
                # TODO supports a list of ColoTensor as args
                if len(input_list) > 0:
                    cur_op_node.add_prev_tensor(input_list[0])

            outputs = func(*args, **kwargs)

            # building the computing graph, op -> output
            if GraphGlobalEnv().graph_building:
                # TODO supports a list of ColoTensor as args
                if isinstance(outputs[0], ColoTensor):
                    cur_op_node.add_post_tensor(outputs[0])
            return outputs

        return wrapper

    return register_colo_graph_decorator
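As a rough illustration of how the decorator is meant to be used beyond linear, the sketch below applies it to a hypothetical single-tensor op; ``my_scale_imp``, its argument layout, and the tuple return are assumptions and not part of this commit. The indices in ``input_pos`` / ``param_pos`` refer to positions in ``*args`` as seen by the wrapper, and the wrapper indexes ``outputs[0]`` and only records it as a post tensor if it is a ColoTensor.

# Hypothetical usage sketch of register_colo_graph (not part of this commit).
import torch
from colossalai.nn.graph import register_colo_graph, GraphContext


@register_colo_graph(input_pos=[0], param_pos=[1])
def my_scale_imp(x: torch.Tensor, scale: torch.Tensor):
    # return a tuple so the wrapper's outputs[0] check sees the result tensor
    return (x * scale,)


with GraphContext():
    out, = my_scale_imp(torch.randn(4), torch.randn(4))    # recorded as one op node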
colossalai/tensor/distspec.py (modified)

...
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
...
@@ -37,6 +44,7 @@ class _DistSpec:
             res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res


 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
...
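A small sketch of the ``replicate`` helper shown above, assuming it is reached through ``colossalai.tensor.distspec`` (the import path used in linear.py); passing no process group means the global process group.

# Minimal sketch following the code above: a replicated placement over the
# global process group (process_group=None).
from colossalai.tensor import distspec

replicated = distspec.replicate()    # _DistSpec(DistPlacementPattern.REPLICATE, None)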
colossalai/tensor/graph/__init__.py (deleted, 100644 → 0)

from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv

__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
tests/test_tensor/test_graph.py (deleted, 100644 → 0)

import pytest
from torch import nn
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
import gc


class SimpleNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x


def _visit_graph(start_node):
    if start_node is None:
        return
    start_node.print()
    post_node_list = start_node.post_nodes
    for node in post_node_list:
        _visit_graph(node)


def _get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occurred: {}'.format(e))


def _count_tensors():
    cnt = 0
    for t in _get_tensors():
        cnt += 1
    return cnt


def count_tensors(use_colossal):
    model = SimpleNet()
    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
            with graph_ctx:
                output = model(colo_input)
                output = model(colo_input)
            ret = _count_tensors()
            _visit_graph(graph_ctx.graph_nodes[0])
            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


@pytest.mark.skip
# FIXME(ver217)
def test_check_activation_tensors():
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)