Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
bdef9dfd
Unverified
Commit
bdef9dfd
authored
Dec 20, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 20, 2022
Browse files
[NFC] remove useless graph node code (#2150)
parent
b3f73ce1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
151 deletions
+0
-151
colossalai/nn/graph/__init__.py
colossalai/nn/graph/__init__.py
+0
-4
colossalai/nn/graph/graph_node.py
colossalai/nn/graph/graph_node.py
+0
-96
colossalai/nn/graph/utils.py
colossalai/nn/graph/utils.py
+0
-51
No files found.
colossalai/nn/graph/__init__.py
deleted
100644 → 0
View file @
b3f73ce1
from
.utils
import
register_colo_graph
from
.graph_node
import
GraphContext
,
GraphGlobalEnv
,
GraphOpNode
__all__
=
[
'register_colo_graph'
,
'GraphContext'
,
'GraphGlobalEnv'
,
'GraphOpNode'
]
\ No newline at end of file
colossalai/nn/graph/graph_node.py
deleted
100644 → 0
View file @
b3f73ce1
from
colossalai.tensor
import
ColoTensor
from
colossalai.context.singleton_meta
import
SingletonMeta
class GraphGlobalEnv(metaclass=SingletonMeta):
    """Singleton holding the global state of computing-graph capture.

    # graph_building: True while a graph-building context is active
    # graph_node_list: op nodes recorded during the current capture
    # node_id: last id handed out; -1 means none issued yet
    """

    def __init__(self) -> None:
        self.graph_building = False
        self.graph_node_list = []
        self.node_id = -1

    def get_node_id(self):
        """Return the next monotonically increasing node id."""
        next_id = self.node_id + 1
        self.node_id = next_id
        return next_id

    def add_graph_node(self, node):
        """Record *node* in the global capture list."""
        self.graph_node_list.append(node)
class GraphContext():
    """
    Build the computing graph while the context is active.

    >>> with GraphContext() as ctx:
    >>>     output = model(colo_input_tensor)

    After exit, the captured op nodes are available on ``ctx.graph_nodes``.
    """
    # Kept for backward compatibility with any code reading the class
    # attribute; instances shadow it with their own list (see __init__).
    graph_nodes = []

    def __init__(self) -> None:
        # Per-instance list: the original used only the shared mutable class
        # attribute, so every context instance aliased the same list.
        self.graph_nodes = []

    def __enter__(self):
        GraphGlobalEnv().graph_building = True
        GraphGlobalEnv().graph_node_list = []
        # Return the context object so ``with GraphContext() as ctx:`` binds
        # it (the original returned None, leaving the result unreachable).
        return self

    def __exit__(self, *exc_info):
        GraphGlobalEnv().graph_building = False
        GraphGlobalEnv().node_id = -1
        # Snapshot the nodes captured during this context.
        self.graph_nodes = GraphGlobalEnv().graph_node_list
class GraphNode(object):
    """A node of the computing graph, linked to predecessors and successors.

    # prev_nodes / post_nodes: adjacent nodes in the captured graph
    # id: unique id issued by the GraphGlobalEnv singleton
    """

    def __init__(self) -> None:
        self.prev_nodes = []
        self.post_nodes = []
        self.id = GraphGlobalEnv().get_node_id()

    def add_prev_node(self, node):
        """Append *node* as a predecessor; no-op unless graph building is on."""
        if not GraphGlobalEnv().graph_building:
            return
        self.prev_nodes.append(node)

    def add_post_node(self, node):
        """Append *node* as a successor; no-op unless graph building is on."""
        if not GraphGlobalEnv().graph_building:
            return
        self.post_nodes.append(node)

    def post_node_empty(self) -> bool:
        """Return True when this node has no successors."""
        return not self.post_nodes
class GraphOpNode(GraphNode):
    """Graph node representing one operator; registers itself globally."""

    def __init__(self, op_type, param_list) -> None:
        super().__init__()
        self._op_type = op_type
        self._param_list = param_list
        # Every op node is recorded in the global capture list on creation.
        GraphGlobalEnv().add_graph_node(self)

    def add_prev_tensor(self, colo_tensor: ColoTensor):
        r"""
        Link the current graph op node to previous graph op.
        Op1 <- Activation (colo_tensor) Op2
        Op1 <- Op2
        """
        if not GraphGlobalEnv().graph_building:
            return
        assert isinstance(colo_tensor, ColoTensor)
        # Lazily attach a graph node to the activation tensor.
        if colo_tensor._graph_node is None:
            colo_tensor._graph_node = GraphNode()
        # Wire this op to every op that produced the activation.
        for producer in colo_tensor._graph_node.prev_nodes:
            self.add_prev_node(producer)
            producer.add_post_node(self)

    def add_post_tensor(self, colo_tensor: ColoTensor):
        """
        Op <- Activation (colo_tensor)
        """
        if not GraphGlobalEnv().graph_building:
            return
        assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
        if colo_tensor._graph_node is None:
            colo_tensor._graph_node = GraphNode()
        # Mark this op as the producer of the activation tensor.
        colo_tensor._graph_node.add_prev_node(self)

    def print(self):
        """Print a one-line human-readable summary of this op node."""
        post_ids = [node.id for node in self.post_nodes]
        prev_ids = [node.id for node in self.prev_nodes]
        print(f'GraphOpNode {self._op_type} {self.id}, post nodes {post_ids}, prev node number {prev_ids}')
colossalai/nn/graph/utils.py
deleted
100644 → 0
View file @
b3f73ce1
import
functools
import
torch
from
colossalai.tensor
import
ColoTensor
from
typing
import
Callable
,
List
from
colossalai.nn._ops._utils
import
convert_to_colo_tensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register a Op (Layer) to ColoGraph.
    Recoders the input args in types of ColoTensor to the Graph.

    Args:
        input_pos (List[int]): positional indices of args treated as inputs.
        param_pos (List[int]): positional indices of args treated as parameters.

    Returns:
        Callable: a decorator; applying it to an Op function returns a
        wrapper that records the Op in the graph while building is active.
    """

    def register_colo_graph_decorator(func):
        # Imported lazily to avoid a circular import at module load time.
        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            param_list = []
            input_list = []
            # TODO(jiaruifang) find the pg
            for idx, arg in enumerate(args):
                if isinstance(arg, torch.Tensor) and idx in input_pos:
                    input_list.append(convert_to_colo_tensor(arg))
                if isinstance(arg, torch.Tensor) and idx in param_pos:
                    param_list.append(convert_to_colo_tensor(arg))

            # Building the computing graph: inputs -> op.
            # Initialize to None so the post-call section below cannot hit an
            # UnboundLocalError if graph_building turns on during func().
            cur_op_node = None
            if GraphGlobalEnv().graph_building:
                # NOTE(review): op type is hard-coded to 'linear' — presumably
                # a placeholder; confirm against callers.
                cur_op_node = GraphOpNode('linear', param_list)
                # TODO supports a list of ColoTensor as args
                if len(input_list) > 0:
                    cur_op_node.add_prev_tensor(input_list[0])

            outputs = func(*args, **kwargs)

            # Building the computing graph: op -> output.
            if cur_op_node is not None and GraphGlobalEnv().graph_building:
                # TODO supports a list of ColoTensor as args
                if isinstance(outputs[0], ColoTensor):
                    cur_op_node.add_post_tensor(outputs[0])
            return outputs

        return wrapper

    return register_colo_graph_decorator
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment