OpenDAS / ColossalAI · Commit fad3b6d1

Authored Nov 15, 2022 by oahzxl

polish code

Parent: 7e2bd1e4

Showing 1 changed file with 239 additions and 239 deletions
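In brief: this commit renames the NodeIndexTracer helpers (add_index, inherit_computation, mark_idx_equal, mark_computation, and the find_*/assign_* families) with a leading underscore to mark them as private, and gathers the module-level memory-estimation helpers into a new MemoryEstimator class whose public entry points are estimate_inference_mem and estimate_chunk_inference_mem. The shared helper _delete_free_var_from_last_use stays a free function, moved to the top of the file.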
chunk_codegen.py    +239 −239    View file @ fad3b6d1
@@ -10,6 +10,13 @@ CODEGEN_AVAILABLE = True
 __all__ = ['ChunkCodeGen']
 
+def _delete_free_var_from_last_use(user_to_last_uses):
+    for key, value in user_to_last_uses.items():
+        for n in value:
+            if n.op == 'placeholder':
+                user_to_last_uses[key].remove(n)
+
+
 class NodeIndexTracer(object):
 
     def __init__(self, gm) -> None:
         self.gm = gm
@@ -19,7 +26,7 @@ class NodeIndexTracer(object):
         self.idx_view_list = []
         self.idx_count = -1
 
-    def add_index(self):
+    def _add_index(self):
         """
         Update the count and return it. To record the idx number.
@@ -29,7 +36,7 @@ class NodeIndexTracer(object):
         self.idx_count += 1
         return self.idx_count
 
-    def inherit_computation(self, node_from, node_to):
+    def _inherit_computation(self, node_from, node_to):
         """
         Inherit computed dim from node_from to node_to.
         If a dim in node_from is marked as computed and exists in node_to,
@@ -39,13 +46,13 @@ class NodeIndexTracer(object):
             node_from (node): node to be inherited
             node_to (node): new node to inherit
         """
-        _, compute_from = self.find_trace_from_node(node_from)
-        idx_to, compute_to = self.find_trace_from_node(node_to)
+        _, compute_from = self._find_trace_from_node(node_from)
+        idx_to, compute_to = self._find_trace_from_node(node_to)
         for i in compute_from:
             if i in idx_to and i not in compute_to:
                 compute_to.append(i)
 
-    def mark_idx_equal(self, idx1, idx2):
+    def _mark_idx_equal(self, idx1, idx2):
         """
         Mark 2 index to be equal.
@@ -55,7 +62,7 @@ class NodeIndexTracer(object):
         """
         self.idx_trace_equal.append((idx1, idx2))
 
-    def mark_computation(self, node, idx, dim):
+    def _mark_computation(self, node, idx, dim):
         """
         Mark some dims of node as computed.
@@ -64,7 +71,7 @@ class NodeIndexTracer(object):
             idx (int): node index
             dim (list or int): dims to be marked as computed
         """
-        input_node_idx_trace = self.find_idx_trace_from_node(node)
+        input_node_idx_trace = self._find_idx_trace_from_node(node)
         if isinstance(dim, int):
             dim = [dim]
         for d in dim:
@@ -72,7 +79,7 @@ class NodeIndexTracer(object):
             if cur_idx not in self.idx_trace_list[idx]['compute']:
                 self.idx_trace_list[idx]['compute'].append(cur_idx)
 
-    def find_trace_from_node(self, node):
+    def _find_trace_from_node(self, node):
         """
         Find node idx and compute trace by the node.
@@ -86,7 +93,7 @@ class NodeIndexTracer(object):
         node_dict = self.idx_trace_list[node_idx]
         return node_dict['idx'], node_dict['compute']
 
-    def find_idx_trace_from_node(self, node):
+    def _find_idx_trace_from_node(self, node):
         """
         Find node idx trace by the node.
@@ -98,7 +105,7 @@ class NodeIndexTracer(object):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
         return self.idx_trace_list[node_idx]['idx']
 
-    def find_compute_trace_from_node(self, node):
+    def _find_compute_trace_from_node(self, node):
         """
         Find node compute trace by the node.
@@ -110,7 +117,7 @@ class NodeIndexTracer(object):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
         return self.idx_trace_list[node_idx]['compute']
 
-    def assign_index_as_input(self, node, node_idx):
+    def _assign_index_as_input(self, node, node_idx):
         """
         Assign node's trace as its input node.
@@ -124,7 +131,7 @@ class NodeIndexTracer(object):
         new_idx_trace = copy.deepcopy(input_node_idx_trace)
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
 
-    def assign_all_index(self, node, node_idx):
+    def _assign_all_index(self, node, node_idx):
         """
         Add new index for all node's dims.
@@ -135,10 +142,10 @@ class NodeIndexTracer(object):
         shape = node.meta['tensor_meta'].shape
         new_trace = []
         for _ in shape:
-            new_trace.append(self.add_index())
+            new_trace.append(self._add_index())
         self.idx_trace_list[node_idx]['idx'] = new_trace
 
-    def assign_transpose_index(self, node, node_idx):
+    def _assign_transpose_index(self, node, node_idx):
         """
         Assign index for transpose op.
         1. swap input's dim according to transpose args
@@ -149,16 +156,16 @@ class NodeIndexTracer(object):
             node_idx (int)
         """
         tranpose_dim = node.args[1:]
-        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
+        input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
         new_idx_trace = copy.deepcopy(input_node_idx_trace)
         new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
         new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
-        self.inherit_computation(node.args[0], node)
+        self._inherit_computation(node.args[0], node)
 
-    def assign_permute_index(self, node, node_idx):
+    def _assign_permute_index(self, node, node_idx):
         """
         Assign index for permute op.
         1. swap input's dim according to permute args
@@ -169,16 +176,16 @@ class NodeIndexTracer(object):
             node_idx (int)
         """
         permute_dim = node.args[1:]
-        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
+        input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
         new_idx_trace = copy.deepcopy(input_node_idx_trace)
         for idx, d in enumerate(permute_dim):
             new_idx_trace[idx] = input_node_idx_trace[d]
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
-        self.inherit_computation(node.args[0], node)
+        self._inherit_computation(node.args[0], node)
 
-    def assign_linear_index(self, node, node_idx):
+    def _assign_linear_index(self, node, node_idx):
         """
         Assign index for linear op.
         1. copy trace from input node and change last index accroding to weight
@@ -190,22 +197,22 @@ class NodeIndexTracer(object):
             node_idx (int)
         """
         input_node, weight, bias = node.args
-        input_node_idx_trace = self.find_idx_trace_from_node(input_node)
-        weight_idx_trace = self.find_idx_trace_from_node(weight)
+        input_node_idx_trace = self._find_idx_trace_from_node(input_node)
+        weight_idx_trace = self._find_idx_trace_from_node(weight)
         new_idx_trace = copy.deepcopy(input_node_idx_trace)
         new_idx_trace[-1] = weight_idx_trace[1]
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
 
-        self.inherit_computation(input_node, node)
-        self.mark_computation(node, node_idx, [-1])
-        self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
+        self._inherit_computation(input_node, node)
+        self._mark_computation(node, node_idx, [-1])
+        self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
 
         if bias:
-            bias_idx_trace = self.find_idx_trace_from_node(bias)
-            self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
+            bias_idx_trace = self._find_idx_trace_from_node(bias)
+            self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
 
-    def assign_matmul_index(self, node, node_idx):
+    def _assign_matmul_index(self, node, node_idx):
         """
         Assign index for matmul op.
         1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
@@ -217,20 +224,20 @@ class NodeIndexTracer(object):
             node_idx (int)
         """
         matmul_left, matmul_right = node.args
-        matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
-        matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
+        matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
+        matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)
 
         assert (len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
         new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
         new_idx_trace[-1] = matmul_right_idx_trace[-1]
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
 
-        self.inherit_computation(matmul_left, node)
-        self.inherit_computation(matmul_right, node)
-        self.mark_computation(node, node_idx, [-1])
-        self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
+        self._inherit_computation(matmul_left, node)
+        self._inherit_computation(matmul_right, node)
+        self._mark_computation(node, node_idx, [-1])
+        self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
 
-    def assign_layernorm_index(self, node, idx):
+    def _assign_layernorm_index(self, node, idx):
         """
         Assign index for layernorm op.
         1. assign index as input node
@@ -240,11 +247,11 @@ class NodeIndexTracer(object):
             node (node)
             node_idx (int)
         """
-        self.assign_index_as_input(node, idx)
-        self.inherit_computation(node.args[0], node)
-        self.mark_computation(node, idx, [-1, -2])
+        self._assign_index_as_input(node, idx)
+        self._inherit_computation(node.args[0], node)
+        self._mark_computation(node, idx, [-1, -2])
 
-    def assign_elementwise_index(self, node, idx):
+    def _assign_elementwise_index(self, node, idx):
         """
         Assign index for element-wise op (eg. relu sigmoid add mul).
         1. assign index as input node
@@ -254,12 +261,12 @@ class NodeIndexTracer(object):
             node (node)
             node_idx (int)
         """
-        self.assign_index_as_input(node, idx)
+        self._assign_index_as_input(node, idx)
         for node_in in node.args:
            if type(node_in) not in (int, float):
-                self.inherit_computation(node_in, node)
+                self._inherit_computation(node_in, node)
 
-    def assign_softmax_index(self, node, idx):
+    def _assign_softmax_index(self, node, idx):
         """
         Assign index for softmax op.
         1. assign index as input node
@@ -269,11 +276,11 @@ class NodeIndexTracer(object):
             node (node)
             node_idx (int)
         """
-        self.assign_index_as_input(node, idx)
-        self.inherit_computation(node.args[0], node)
-        self.mark_computation(node, idx, [node.kwargs['dim']])
+        self._assign_index_as_input(node, idx)
+        self._inherit_computation(node.args[0], node)
+        self._mark_computation(node, idx, [node.kwargs['dim']])
 
-    def assign_view_reshape_index(self, node, node_idx):
+    def _assign_view_reshape_index(self, node, node_idx):
         """
         Assign index for view and reshape op.
         1. get origin shape and target shape by meta info.
@@ -325,22 +332,22 @@ class NodeIndexTracer(object):
             raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
 
         # get new index
-        origin_trace = self.find_idx_trace_from_node(origin_node)
+        origin_trace = self._find_idx_trace_from_node(origin_node)
         new_trace = copy.deepcopy(origin_trace)
         dim_from.reverse()
         for i in dim_from:
             new_trace.pop(i)
         for i in dim_to:
-            new_trace.insert(i, self.add_index())
+            new_trace.insert(i, self._add_index())
         self.idx_trace_list[node_idx]['idx'] = new_trace
 
         # inherit computation
-        self.inherit_computation(origin_node, node)
-        compute_log = self.find_compute_trace_from_node(origin_node)
+        self._inherit_computation(origin_node, node)
+        compute_log = self._find_compute_trace_from_node(origin_node)
         for i in dim_from:
             if origin_trace[i] in compute_log:
                 for j in dim_to:
-                    self.mark_computation(node, node_idx, [j])
+                    self._mark_computation(node, node_idx, [j])
                 break
 
         # log view, not used now
@@ -353,25 +360,25 @@ class NodeIndexTracer(object):
     def trace_node_idx(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
-                self.assign_all_index(node, idx)
+                self._assign_all_index(node, idx)
             elif node.op == 'call_method':
                 if 'transpose' in node.name:
-                    self.assign_transpose_index(node, idx)
+                    self._assign_transpose_index(node, idx)
                 elif 'permute' in node.name:
-                    self.assign_permute_index(node, idx)
+                    self._assign_permute_index(node, idx)
                 elif 'view' in node.name or 'reshape' in node.name:
-                    self.assign_view_reshape_index(node, idx)
+                    self._assign_view_reshape_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "method not implemented yet!")
             elif node.op == 'call_function':
                 if 'linear' in node.name:
-                    self.assign_linear_index(node, idx)
+                    self._assign_linear_index(node, idx)
                 elif 'matmul' in node.name:
-                    self.assign_matmul_index(node, idx)
+                    self._assign_matmul_index(node, idx)
                 elif 'softmax' in node.name:
-                    self.assign_softmax_index(node, idx)
+                    self._assign_softmax_index(node, idx)
                 elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
-                    self.assign_elementwise_index(node, idx)
+                    self._assign_elementwise_index(node, idx)
                 elif 'getattr' in node.name:
                     continue    # get attr like shape
                 elif 'getitem' in node.name:
@@ -380,206 +387,198 @@ class NodeIndexTracer(object):
                     raise NotImplementedError(node.name, "function not implemented yet!")
             elif node.op == 'call_module':
                 if any(n in node.name for n in ['layernorm', 'norm']):
-                    self.assign_layernorm_index(node, idx)
+                    self._assign_layernorm_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "module not implemented yet!")
             elif node.op == 'get_attr':
-                self.assign_all_index(node, idx)    # get param
+                self._assign_all_index(node, idx)    # get param
             elif node.op == 'output':
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
 
-def _get_meta_node_size(x):
-    x = x.meta['tensor_meta']
-    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
-    return x
-
-def _get_output_node_size(n):
-    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
-    return activation_size(fwd_out)
-
-def _get_delete_node_size(user, user_to_last_uses):
-    if user.op in ('placeholder', 'output'):
-        return 0
-    nodes_to_delete = user_to_last_uses.get(user, [])
-    if len(nodes_to_delete):
-        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
-        return delete_size
-    return 0
-
-def _get_last_usr(nodes):
-    node_to_last_use: Dict[Node, Node] = {}
-    user_to_last_uses: Dict[Node, List[Node]] = {}
-
-    def register_last_uses(n: Node, user: Node):
-        if n not in node_to_last_use:
-            node_to_last_use[n] = user
-            user_to_last_uses.setdefault(user, []).append(n)
-
-    for node in reversed(nodes):
-        map_arg(node.args, lambda n: register_last_uses(n, node))
-        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
-    return user_to_last_uses
-
-def _delete_free_var_from_last_use(user_to_last_uses):
-    for key, value in user_to_last_uses.items():
-        for n in value:
-            if n.op == 'placeholder':
-                user_to_last_uses[key].remove(n)
-
-def _get_contiguous_memory(node, not_contiguous_list, delete=False):
-    mem = 0
-    not_contiguous_ops = ['transpose', 'permute']
-
-    if node.op == 'call_function' and 'matmul' in node.name:
-        for n in node.args:
-            if n in not_contiguous_list:
-                # matmul won't change origin tensor, but create a tmp copy
-                mem += _get_output_node_size(n)
-    elif node.op == 'call_module':
-        for n in node.args:
-            if n in not_contiguous_list:
-                # module will just make origin tensor to contiguous
-                if delete:
-                    not_contiguous_list.remove(n)
-    elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
-        if node not in not_contiguous_list:
-            not_contiguous_list.append(node)
-    elif any(i in node.args for i in not_contiguous_list):
-        if node not in not_contiguous_list:
-            not_contiguous_list.append(node)
-    return mem
-
-def _estimate_inference_mem(gm: torch.fx.GraphModule):
-    act_memory = 0.0
-    act_memory_peak_log = []
-    act_memory_after_node_log = []
-    not_contiguous_list = []
-    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
-    _delete_free_var_from_last_use(user_to_last_uses)
-    for node in gm.graph.nodes:
-        # if node is placeholder, just add the size of the node
-        if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node) / (1024**2)
-            act_memory_peak_log.append(act_memory)
-            act_memory_after_node_log.append(act_memory)
-        # skip output
-        elif node.op == 'output':
-            continue
-        # node is an operation, calculate tmp, output node and delete node memory
-        else:
-            # forward memory
-            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024**2)
-            act_memory += _get_output_node_size(node) / (1024**2)
-            # record max act memory
-            act_memory_peak_log.append(act_memory)
-            # delete useless memory
-            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024**2)
-            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024**2)
-            act_memory_after_node_log.append(act_memory)
-
-    print("no chunk")
-    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
-    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
-
-    param_memory = parameter_size(gm)
-    return act_memory + param_memory, param_memory
-
-def _get_chunk_ratio(node, chunk_dim, chunk_size):
-    shape = node.meta['tensor_meta'].shape
-    chunk_ratio = float(chunk_size) / shape[chunk_dim]
-    return chunk_ratio
-
-def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
-    if user.op in ('placeholder', 'output'):
-        return 0
-    nodes_to_delete = user_to_last_uses.get(user, [])
-    delete_size = 0
-    for n in nodes_to_delete:
-        node_idx = _find_idx_by_name(n.name, node_list)
-        if start_node <= node_idx < end_node:
-            delete_size += _get_output_node_size(n) * chunk_ratio
-    return delete_size
-
-def _print_mem_log(log, nodes, title=None):
-    if title:
-        print(title)
-    for idx, (l, n) in enumerate(zip(log, nodes)):
-        print("%s:%.2f \t" % (n.name, l), end='')
-        if (idx + 1) % 3 == 0:
-            print("")
-    print("\n")
-
-def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
-    act_memory = 0.0
-    act_memory_peak_log = []
-    act_memory_after_node_log = []
-    not_contiguous_list = []
-    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
-    _delete_free_var_from_last_use(user_to_last_uses)
-    within_chunk = False
-    region_idx = 0
-    chunk_ratio = 1    # use it to estimate chunk mem
-    node_list = list(gm.graph.nodes)
-
-    for idx, node in enumerate(node_list):
-        # if node in chunk start nodes, change chunk ratio and add chunk_tensor
-        if idx in start_nodes:
-            within_chunk = True
-            chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024**2)
-
-        # if node is placeholder, just add the size of the node
-        if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024**2)
-            act_memory_peak_log.append(act_memory)
-        # skip output
-        elif node.op == 'output':
-            continue
-        # node is an operation, calculate tmp, output node and delete node memory
-        else:
-            # forward memory
-            # TODO: permute will create a tmp copy if not contiguous
-            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2)
-            act_memory += _get_output_node_size(node) * chunk_ratio / (1024**2)
-            # record max act memory
-            act_memory_peak_log.append(act_memory)
-            # delete useless memory
-            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024**2)
-            if within_chunk:
-                act_memory -= _get_chunk_delete_node_size(node, user_to_last_uses, chunk_ratio, node_list,
-                                                          start_nodes[region_idx], end_nodes[region_idx]) / (1024**2)
-            else:
-                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024**2)
-
-        if idx in end_nodes:
-            act_memory -= _get_output_node_size(node) * chunk_ratio / (1024**2)
-            within_chunk = False
-            chunk_ratio = 1
-            region_idx += 1
-
-        act_memory_after_node_log.append(act_memory)
-
-    print("chunk")
-    _print_mem_log(act_memory_peak_log, node_list, "peak")
-    _print_mem_log(act_memory_after_node_log, node_list, "after")
-
-    param_memory = parameter_size(gm)
-    return act_memory + param_memory, param_memory
+
+class MemoryEstimator(object):
+
+    def __init__(self) -> None:
+        pass
+
+    def _get_meta_node_size(self, x):
+        x = x.meta['tensor_meta']
+        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+        return x
+
+    def _get_output_node_size(self, n):
+        fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+        return activation_size(fwd_out)
+
+    def _get_delete_node_size(self, user, user_to_last_uses):
+        if user.op in ('placeholder', 'output'):
+            return 0
+        nodes_to_delete = user_to_last_uses.get(user, [])
+        if len(nodes_to_delete):
+            delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete])
+            return delete_size
+        return 0
+
+    def _get_last_usr(self, nodes):
+        node_to_last_use: Dict[Node, Node] = {}
+        user_to_last_uses: Dict[Node, List[Node]] = {}
+
+        def register_last_uses(n: Node, user: Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+                user_to_last_uses.setdefault(user, []).append(n)
+
+        for node in reversed(nodes):
+            map_arg(node.args, lambda n: register_last_uses(n, node))
+            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+        return user_to_last_uses
+
+    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
+        mem = 0
+        not_contiguous_ops = ['transpose', 'permute']
+
+        if node.op == 'call_function' and 'matmul' in node.name:
+            for n in node.args:
+                if n in not_contiguous_list:
+                    # matmul won't change origin tensor, but create a tmp copy
+                    mem += self._get_output_node_size(n)
+        elif node.op == 'call_module':
+            for n in node.args:
+                if n in not_contiguous_list:
+                    # module will just make origin tensor to contiguous
+                    if delete:
+                        not_contiguous_list.remove(n)
+        elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
+            if node not in not_contiguous_list:
+                not_contiguous_list.append(node)
+        elif any(i in node.args for i in not_contiguous_list):
+            if node not in not_contiguous_list:
+                not_contiguous_list.append(node)
+        return mem
+
+    def estimate_inference_mem(self, gm: torch.fx.GraphModule):
+        act_memory = 0.0
+        act_memory_peak_log = []
+        act_memory_after_node_log = []
+        not_contiguous_list = []
+        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
+        _delete_free_var_from_last_use(user_to_last_uses)
+        for node in gm.graph.nodes:
+            # if node is placeholder, just add the size of the node
+            if node.op == 'placeholder':
+                act_memory += self._get_meta_node_size(node) / (1024**2)
+                act_memory_peak_log.append(act_memory)
+                act_memory_after_node_log.append(act_memory)
+            # skip output
+            elif node.op == 'output':
+                continue
+            # node is an operation, calculate tmp, output node and delete node memory
+            else:
+                # forward memory
+                act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024**2)
+                act_memory += self._get_output_node_size(node) / (1024**2)
+                # record max act memory
+                act_memory_peak_log.append(act_memory)
+                # delete useless memory
+                act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024**2)
+                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024**2)
+                act_memory_after_node_log.append(act_memory)
+
+        print("no chunk")
+        self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
+        self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
+
+        param_memory = parameter_size(gm)
+        return act_memory + param_memory, param_memory
+
+    def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
+        shape = node.meta['tensor_meta'].shape
+        chunk_ratio = float(chunk_size) / shape[chunk_dim]
+        return chunk_ratio
+
+    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
+        if user.op in ('placeholder', 'output'):
+            return 0
+        nodes_to_delete = user_to_last_uses.get(user, [])
+        delete_size = 0
+        for n in nodes_to_delete:
+            node_idx = _find_idx_by_name(n.name, node_list)
+            if start_node <= node_idx < end_node:
+                delete_size += self._get_output_node_size(n) * chunk_ratio
+        return delete_size
+
+    def _print_mem_log(self, log, nodes, title=None):
+        if title:
+            print(title)
+        for idx, (l, n) in enumerate(zip(log, nodes)):
+            print("%s:%.2f \t" % (n.name, l), end='')
+            if (idx + 1) % 3 == 0:
+                print("")
+        print("\n")
+
+    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
+        act_memory = 0.0
+        act_memory_peak_log = []
+        act_memory_after_node_log = []
+        not_contiguous_list = []
+        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
+        _delete_free_var_from_last_use(user_to_last_uses)
+        within_chunk = False
+        region_idx = 0
+        chunk_ratio = 1    # use it to estimate chunk mem
+        node_list = list(gm.graph.nodes)
+
+        for idx, node in enumerate(node_list):
+            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
+            if idx in start_nodes:
+                within_chunk = True
+                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
+                act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024**2)
+
+            # if node is placeholder, just add the size of the node
+            if node.op == 'placeholder':
+                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
+                act_memory_peak_log.append(act_memory)
+            # skip output
+            elif node.op == 'output':
+                continue
+            # node is an operation, calculate tmp, output node and delete node memory
+            else:
+                # forward memory
+                # TODO: permute will create a tmp copy if not contiguous
+                act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2)
+                act_memory += self._get_output_node_size(node) * chunk_ratio / (1024**2)
+                # record max act memory
+                act_memory_peak_log.append(act_memory)
+                # delete useless memory
+                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024**2)
+                if within_chunk:
+                    act_memory -= self._get_chunk_delete_node_size(node, user_to_last_uses, chunk_ratio, node_list,
+                                                                   start_nodes[region_idx], end_nodes[region_idx]) / (1024**2)
+                else:
+                    act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024**2)
+
+            if idx in end_nodes:
+                act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024**2)
+                within_chunk = False
+                chunk_ratio = 1
+                region_idx += 1
+
+            act_memory_after_node_log.append(act_memory)
+
+        print("chunk")
+        self._print_mem_log(act_memory_peak_log, node_list, "peak")
+        self._print_mem_log(act_memory_after_node_log, node_list, "after")
+
+        param_memory = parameter_size(gm)
+        return act_memory + param_memory, param_memory
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@@ -695,8 +694,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False
 
     node_list = list(nodes)
 
-    _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
-    _estimate_inference_mem(meta_graph)
+    memory_estimator = MemoryEstimator()
+    memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
+    memory_estimator.estimate_inference_mem(meta_graph)
     node_index_tracer = NodeIndexTracer(meta_graph)
     node_index_tracer.trace_node_idx()
...
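For orientation, the refactor only changes how the estimator is reached, not what it computes. Below is a minimal sketch of the new call sites, assuming chunk_codegen.py is importable and that meta_graph, chunk_starts, and chunk_ends exist as in emit_code_with_chunk above (the import path is illustrative, not taken from the diff):

# Sketch of the post-commit calling convention; names mirror the diff above.
# Assumptions: `meta_graph` is a torch.fx.GraphModule whose nodes carry
# 'tensor_meta' and 'fwd_out' metadata, and `chunk_starts`/`chunk_ends` are
# node indices delimiting the chunked region (hypothetical inputs).
from chunk_codegen import MemoryEstimator, NodeIndexTracer

# Memory estimation is now instance-based instead of free functions.
memory_estimator = MemoryEstimator()

# Chunked estimate: dim 1 split into chunks of size 2, the hard-coded values
# used in emit_code_with_chunk; both methods return (act + param, param) memory.
memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
memory_estimator.estimate_inference_mem(meta_graph)

# Index tracing is unchanged at the call site; its per-op helpers
# (e.g. _assign_matmul_index) are now private by naming convention.
node_index_tracer = NodeIndexTracer(meta_graph)
node_index_tracer.trace_node_idx()

Internally, estimate_chunk_inference_mem scales each node's activation footprint by chunk_ratio = chunk_size / shape[chunk_dim] (see _get_chunk_ratio): with the [1], [2] arguments above, a node whose dim 1 has size 128 contributes only 2/128 = 1/64 of its full activation memory while inside the chunk region.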