Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c36dba07
Commit
c36dba07
authored
Nov 14, 2022
by
oahzxl
Browse files
finish basic index tracer
parent
1607d04e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
124 additions
and
9 deletions
+124
-9
chunk_codegen.py
chunk_codegen.py
+124
-9
No files found.
chunk_codegen.py
View file @
c36dba07
...
...
@@ -25,6 +25,7 @@ class NodeIndexTracer(object):
self
.
nodes_list
=
list
(
gm
.
graph
.
nodes
)
self
.
idx_trace_list
=
[{
'idx'
:
[],
'compute'
:
[]}
for
_
in
range
(
len
(
self
.
nodes_list
))]
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
[]
self
.
idx_count
=
1
def
add_index
(
self
):
...
...
@@ -35,7 +36,7 @@ class NodeIndexTracer(object):
_
,
compute_from
=
self
.
find_trace_from_node
(
node_from
)
idx_to
,
compute_to
=
self
.
find_trace_from_node
(
node_to
)
for
i
in
compute_from
:
if
i
in
idx_to
:
if
i
in
idx_to
and
i
not
in
compute_to
:
compute_to
.
append
(
i
)
def
mark_idx_equal
(
self
,
idx1
,
idx2
):
...
...
@@ -47,7 +48,8 @@ class NodeIndexTracer(object):
dim
=
[
dim
]
for
d
in
dim
:
cur_idx
=
input_node_idx_trace
[
d
]
self
.
idx_trace_list
[
idx
][
'compute'
].
append
(
cur_idx
)
if
cur_idx
not
in
self
.
idx_trace_list
[
idx
][
'compute'
]:
self
.
idx_trace_list
[
idx
][
'compute'
].
append
(
cur_idx
)
def
find_trace_from_node
(
self
,
node
):
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
...
...
@@ -56,8 +58,11 @@ class NodeIndexTracer(object):
def find_idx_trace_from_node(self, node):
    """Return the dimension-index trace recorded for *node*.

    The node is located by name in ``self.nodes_list``; the returned
    value is the ``'idx'`` list of its trace record (the stored list
    itself, not a copy).
    """
    position = _find_idx_by_name(node.name, self.nodes_list)
    return self.idx_trace_list[position]['idx']
def find_compute_trace_from_node(self, node):
    """Return the computation trace recorded for *node*.

    The node is located by name in ``self.nodes_list``; the returned
    value is the ``'compute'`` list of its trace record (the stored
    list itself, not a copy).
    """
    position = _find_idx_by_name(node.name, self.nodes_list)
    return self.idx_trace_list[position]['compute']
def
assign_index_as_input
(
self
,
node
,
node_idx
):
input_node_idx
=
_find_idx_by_name
(
node
.
args
[
0
].
name
,
self
.
nodes_list
)
...
...
@@ -82,6 +87,18 @@ class NodeIndexTracer(object):
new_idx_trace
[
tranpose_dim
[
1
]]
=
input_node_idx_trace
[
tranpose_dim
[
0
]]
self
.
idx_trace_list
[
node_idx
][
'idx'
]
=
new_idx_trace
self
.
inherit_computation
(
node
.
args
[
0
],
node
)
def assign_permute_index(self, node, node_idx):
    """Propagate index traces through a ``permute`` call.

    ``node.args[1:]`` is the permutation: output dim ``i`` receives the
    index of input dim ``perm[i]``. The computation history of the
    permuted input is inherited by this node.
    """
    perm = node.args[1:]
    source = node.args[0]
    source_trace = self.find_idx_trace_from_node(source)
    permuted_trace = copy.deepcopy(source_trace)
    for out_dim, in_dim in enumerate(perm):
        permuted_trace[out_dim] = source_trace[in_dim]
    self.idx_trace_list[node_idx]['idx'] = permuted_trace
    self.inherit_computation(source, node)
def
assign_linear_index
(
self
,
node
,
node_idx
):
input_node
,
weight
,
bias
=
node
.
args
...
...
@@ -100,10 +117,99 @@ class NodeIndexTracer(object):
bias_idx_trace
=
self
.
find_idx_trace_from_node
(
bias
)
self
.
mark_idx_equal
(
input_node_idx_trace
[
-
1
],
bias_idx_trace
[
0
])
def assign_matmul_index(self, node, node_idx):
    """Propagate index traces through a two-operand matmul.

    The output keeps the left operand's trace except that its last dim
    takes the right operand's last-dim index. The contracted pair
    ``(left[-1], right[-2])`` is recorded as equal, and the output's
    last dim is marked as computed. Both operands must have traces of
    equal length.
    """
    lhs, rhs = node.args
    lhs_trace = self.find_idx_trace_from_node(lhs)
    rhs_trace = self.find_idx_trace_from_node(rhs)
    assert len(lhs_trace) == len(rhs_trace)
    out_trace = copy.deepcopy(lhs_trace)
    out_trace[-1] = rhs_trace[-1]
    self.idx_trace_list[node_idx]['idx'] = out_trace
    self.inherit_computation(lhs, node)
    self.inherit_computation(rhs, node)
    self.mark_computation(node, node_idx, [-1])
    self.mark_idx_equal(lhs_trace[-1], rhs_trace[-2])
def assign_layernorm_index(self, node, idx):
    """LayerNorm: output reuses the input's index trace.

    Computation history is inherited from the normalized input, and the
    last two dims are marked as computed.
    NOTE(review): marking ``[-1, -2]`` assumes normalization spans the
    last two dims — confirm against the models being traced.
    """
    self.assign_index_as_input(node, idx)
    self.inherit_computation(node.args[0], node)
    self.mark_computation(node, idx, [-1, -2])
def assign_elementwise_index(self, node, idx):
    """Element-wise op: output reuses the input's index trace.

    Computation history is inherited from every argument that is not a
    plain numeric scalar. Uses ``isinstance`` rather than an exact
    ``type()`` membership test so that ``bool`` (an ``int`` subclass)
    and other numeric subclasses are also treated as scalars instead of
    being passed to ``inherit_computation`` as if they were graph nodes.
    """
    self.assign_index_as_input(node, idx)
    for arg in node.args:
        if not isinstance(arg, (int, float)):
            self.inherit_computation(arg, node)
def assign_softmax_index(self, node, idx):
    """Softmax: output reuses the input's index trace.

    The dim given by the node's ``dim`` kwarg is marked as computed
    (softmax reduces across that dim).
    """
    self.assign_index_as_input(node, idx)
    softmax_dim = node.kwargs['dim']
    self.mark_computation(node, idx, [softmax_dim])
def assign_view_reshape_index(self, node, node_idx):
    """Propagate index traces through a ``view``/``reshape`` call.

    Supported shapes are exactly one merge of two adjacent dims
    (rank - 1) or one split into two adjacent dims (rank + 1); any
    other reshape raises ``NotImplementedError``. The source dims lose
    their indices, the new dims get fresh indices from ``add_index``,
    and the view is logged in ``self.idx_view_list``.
    """
    # Resolve the requested target shape to plain ints.
    origin_node = node.args[0]
    origin_shape = origin_node.meta['tensor_meta'].shape
    target_shape = []
    for pos in range(1, len(node.args)):
        arg = node.args[pos]
        if isinstance(arg, int):
            target_shape.append(arg)
        else:
            target_shape.append(arg.meta['fwd_out'][0])

    # Replace a -1 placeholder with the inferred dim size.
    if -1 in target_shape:
        origin_product = 1
        for extent in origin_shape:
            origin_product *= extent
        # Start at -1 so the -1 entry in target_shape cancels the sign.
        target_product = -1
        for extent in target_shape:
            target_product *= extent
        target_shape[target_shape.index(-1)] = origin_product // target_product

    # Locate the changed dims: one merge or one split of adjacent dims.
    len_diff = len(origin_shape) - len(target_shape)
    if len_diff == 1:
        # dim merge
        dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
        first_diff = dim_equal.index(False)
        dim_to = [first_diff]
        dim_from = [first_diff, first_diff + 1]
    elif len_diff == -1:
        # dim expand
        dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
        first_diff = dim_equal.index(False)
        dim_from = [first_diff]
        dim_to = [first_diff, first_diff + 1]
    else:
        raise NotImplementedError(
            "shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented"
        )

    # Build the new trace: pop the source dims (descending order so the
    # positions stay valid), then insert fresh indices for the targets.
    origin_trace = self.find_idx_trace_from_node(origin_node)
    new_trace = copy.deepcopy(origin_trace)
    dim_from.reverse()
    for pos in dim_from:
        new_trace.pop(pos)
    for pos in dim_to:
        new_trace.insert(pos, self.add_index())
    self.idx_trace_list[node_idx]['idx'] = new_trace

    # Inherit computation; if any removed dim was computed, mark every
    # new dim as computed too.
    self.inherit_computation(origin_node, node)
    compute_log = self.find_compute_trace_from_node(origin_node)
    for pos in dim_from:
        if origin_trace[pos] in compute_log:
            for target_dim in dim_to:
                self.mark_computation(node, node_idx, [target_dim])
            break

    # Log the view so later passes can inspect or undo it.
    self.idx_view_list.append({
        "idx_from": [origin_trace[pos] for pos in dim_from],
        "dim_from": dim_from,
        "idx_to": [new_trace[pos] for pos in dim_to],
        "dim_to": dim_to,
    })
def
trace_node_idx
(
self
):
for
idx
,
node
in
enumerate
(
self
.
nodes_list
):
if
node
.
op
==
'placeholder'
:
...
...
@@ -111,15 +217,21 @@ class NodeIndexTracer(object):
elif
node
.
op
==
'call_method'
:
if
'transpose'
in
node
.
name
:
self
.
assign_transpose_index
(
node
,
idx
)
elif
'view'
in
node
.
name
:
pass
elif
'permute'
in
node
.
name
:
pass
self
.
assign_permute_index
(
node
,
idx
)
elif
'view'
in
node
.
name
or
'reshape'
in
node
.
name
:
self
.
assign_view_reshape_index
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
'call_function'
:
if
'linear'
in
node
.
name
:
self
.
assign_linear_index
(
node
,
idx
)
elif
'matmul'
in
node
.
name
:
self
.
assign_matmul_index
(
node
,
idx
)
elif
'softmax'
in
node
.
name
:
self
.
assign_softmax_index
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
'mul'
,
'add'
,
'sigmoid'
,
'relu'
]):
self
.
assign_elementwise_index
(
node
,
idx
)
elif
'getattr'
in
node
.
name
:
continue
# get attr like shape
elif
'getitem'
in
node
.
name
:
...
...
@@ -127,12 +239,14 @@ class NodeIndexTracer(object):
else
:
raise
NotImplementedError
(
node
.
name
,
"function not implemented yet!"
)
elif
node
.
op
==
'call_module'
:
if
'layernorm'
in
node
.
name
:
if
any
(
n
in
node
.
name
for
n
in
[
'layernorm'
,
'norm'
])
:
self
.
assign_layernorm_index
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
elif
node
.
op
==
'get_attr'
:
self
.
assign_all_index
(
node
,
idx
)
# get param
elif
node
.
op
==
'output'
:
continue
else
:
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
...
...
@@ -297,6 +411,7 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
# node is an operation, calculate tmp, output node and delete node memory
else
:
# forward memory
# TODO: permute will create a tmp copy if not contiguous
act_memory
+=
_get_contiguous_memory
(
node
,
not_contiguous_list
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory
+=
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
# record max act memory
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment