Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d361d533
Commit
d361d533
authored
Dec 21, 2022
by
oahzxl
Browse files
refactor flow tracer
parent
d734529a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
240 additions
and
54 deletions
+240
-54
chunk_codegen.py
chunk_codegen.py
+231
-52
evoformer/evoformer.py
evoformer/evoformer.py
+9
-2
No files found.
chunk_codegen.py
View file @
d361d533
...
@@ -139,7 +139,13 @@ class IndexTracer(object):
...
@@ -139,7 +139,13 @@ class IndexTracer(object):
node_from_idx
=
_find_idx_by_name
(
node_from
.
name
,
self
.
nodes_list
)
node_from_idx
=
_find_idx_by_name
(
node_from
.
name
,
self
.
nodes_list
)
if
init
:
if
init
:
node_to_trace
[
"source"
][
node_to_dim
]
=
{}
node_to_trace
[
"source"
][
node_to_dim
]
=
{}
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
node_from_dim
# add dim to cur new source
if
node_from_idx
not
in
node_to_trace
[
"source"
][
node_to_dim
]:
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
else
:
if
node_from_dim
not
in
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]:
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
# update inputs source
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_from_trace
[
"source"
][
node_from_dim
]
node_from_trace
[
"source"
][
node_from_dim
]
)
)
...
@@ -654,7 +660,7 @@ class IndexTracer(object):
...
@@ -654,7 +660,7 @@ class IndexTracer(object):
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
)
for
node_idx
,
node_dim
in
sorted_source
:
for
node_idx
,
node_dim
in
sorted_source
:
if
node_idx
==
start_node_idx
and
node_dim
==
start
_dim
:
if
node_idx
==
start_node_idx
and
start_dim
in
node
_dim
:
return
True
return
True
# it means we meet a node outside the loop, and the node is not input node
# it means we meet a node outside the loop, and the node is not input node
if
node_idx
<
start_idx
:
if
node_idx
<
start_idx
:
...
@@ -694,12 +700,12 @@ class IndexTracer(object):
...
@@ -694,12 +700,12 @@ class IndexTracer(object):
for
node_dim
in
range
(
len
(
_get_node_shape
(
node
))):
for
node_dim
in
range
(
len
(
_get_node_shape
(
node
))):
if
(
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
input_node_idx
in
node_trace_source
[
node_dim
]
and
node_trace_source
[
node_dim
][
input_node_idx
]
==
input_dim
and
input_dim
in
node_trace_source
[
node_dim
][
input_node_idx
]
):
):
return
node_dim
return
node_dim
return
None
return
None
def
check_index_duplicate
(
self
,
chunk_infos
):
def
check_index_duplicate
(
self
,
chunk_infos
,
return_dim
=
False
):
input_dim_after_node
=
{}
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
...
@@ -713,17 +719,30 @@ class IndexTracer(object):
...
@@ -713,17 +719,30 @@ class IndexTracer(object):
if
_is_non_compute_node_except_placeholder
(
node
):
if
_is_non_compute_node_except_placeholder
(
node
):
continue
continue
count
=
0
count
=
0
duplicate_dims
=
[]
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
_get_node_shape
(
node
))):
for
node_dim
in
range
(
len
(
_get_node_shape
(
node
))):
duplicate_dim
=
[]
duplicate_flag
=
False
dim_source
=
node_trace_source
[
node_dim
]
dim_source
=
node_trace_source
[
node_dim
]
for
k
,
v
in
dim_source
.
items
():
for
k
,
v
in
dim_source
.
items
():
if
chunk_infos
[
"region"
][
0
]
<=
k
<=
chunk_infos
[
"region"
][
1
]:
if
chunk_infos
[
"region"
][
0
]
<=
k
<=
chunk_infos
[
"region"
][
1
]:
if
k
in
input_dim_after_node
and
input_dim_after_node
[
k
]
==
v
:
if
k
in
input_dim_after_node
and
input_dim_after_node
[
k
]
in
v
:
count
+=
1
duplicate_flag
=
True
break
duplicate_dim
.
append
((
k
,
v
))
duplicate_dims
.
append
(
duplicate_dim
)
if
duplicate_flag
:
count
+=
1
if
count
>
1
:
if
count
>
1
:
return
False
if
return_dim
:
return
True
return
False
,
duplicate_dims
else
:
return
False
if
return_dim
:
return
True
,
None
else
:
return
True
...
@@ -857,43 +876,45 @@ class FlowTracer(object):
...
@@ -857,43 +876,45 @@ class FlowTracer(object):
flow_block
=
True
flow_block
=
True
return
flow_block
,
chunk_info
return
flow_block
,
chunk_info
for
idx
in
range
(
start_idx
,
end_idx
+
1
):
# for idx in range(start_idx, end_idx + 1):
node
=
self
.
node_list
[
idx
]
# node = self.node_list[idx]
mix_flow_node
=
self
.
_get_flow_mix_node
(
node
)
# mix_flow_node = self._get_flow_mix_node(node)
if
mix_flow_node
is
None
:
# if mix_flow_node is None:
continue
# continue
# if there is a flow mix, op must be in [mul, add, matmul]
# # if there is a flow mix, op must be in [mul, add, matmul]
# element-wise op requires dim to be equal in every dim
# # element-wise op requires dim to be equal in every dim
if
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
]):
# if any(n in node.name for n in ["mul", "add"]):
for
i
in
node
.
args
:
# for i in node.args:
if
type
(
i
)
==
type
(
mix_flow_node
)
and
i
!=
mix_flow_node
:
# if type(i) == type(mix_flow_node) and i != mix_flow_node:
main_flow_var
=
i
# main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# # if mix flow is a broadcast in chunk dim,
# TODO: need to move that flow out of the chunk
# # TODO: need to move that flow out of the chunk
mix_flow_node_dim
=
index_tracer
.
get_node_chunk_dim
(
# mix_flow_node_dim = index_tracer.get_node_chunk_dim(
self
.
node_list
[
end_idx
],
end_dim
,
node
# self.node_list[end_idx], end_dim, node
)
# )
if
mix_flow_node_dim
is
None
:
# # TODO: we need to loop every dim
flow_block
=
True
# if isinstance(mix_flow_node_dim, list):
break
# mix_flow_node_dim = mix_flow_node_dim[0]
if
_get_node_shape
(
mix_flow_node
)[
mix_flow_node_dim
]
==
1
:
# if mix_flow_node_dim is None:
flow_block
=
False
# flow_block = True
for
i
in
self
.
_get_same_flow_node
(
# break
chunk_info
[
"inputs"
],
mix_flow_node
# if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
):
# flow_block = False
chunk_info
[
"inputs"
].
remove
(
i
)
# for i in self._get_same_flow_node(
# else, we need to chunk mix var as well
# chunk_info["inputs"], mix_flow_node
else
:
# ):
# TODO chunk another value
# chunk_info["inputs"].remove(i)
flow_block
=
True
# # else, we need to chunk mix var as well
break
# else:
else
:
# # TODO chunk another value
raise
NotImplementedError
(
"%s not implemented"
%
node
.
name
)
# flow_block = True
# break
if
flow_block
:
# else:
flow_block
=
True
# raise NotImplementedError("%s not implemented" % node.name)
return
flow_block
,
chunk_info
# if flow_block:
# flow_block = True
# return flow_block, chunk_info
inputs_dim
=
[]
inputs_dim
=
[]
remove_inputs
=
[]
remove_inputs
=
[]
...
@@ -908,6 +929,9 @@ class FlowTracer(object):
...
@@ -908,6 +929,9 @@ class FlowTracer(object):
dim
=
index_tracer
.
get_node_chunk_dim
(
dim
=
index_tracer
.
get_node_chunk_dim
(
self
.
node_list
[
end_idx
],
end_dim
,
input_node
self
.
node_list
[
end_idx
],
end_dim
,
input_node
)
)
# TODO: we need to loop every dim
if
isinstance
(
dim
,
list
):
dim
=
dim
[
0
]
elif
user_idx
==
end_idx
:
elif
user_idx
==
end_idx
:
dim
=
end_dim
dim
=
end_dim
# n has relation with chunk dim
# n has relation with chunk dim
...
@@ -921,6 +945,8 @@ class FlowTracer(object):
...
@@ -921,6 +945,8 @@ class FlowTracer(object):
for
i
in
remove_inputs
:
for
i
in
remove_inputs
:
if
i
in
chunk_info
[
"inputs"
]:
if
i
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs"
].
remove
(
i
)
chunk_info
[
"inputs"
].
remove
(
i
)
duplicate_result
,
duplicate_dim
=
index_tracer
.
check_index_duplicate
(
chunk_info
,
return_dim
=
True
)
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
...
@@ -932,6 +958,150 @@ class FlowTracer(object):
...
@@ -932,6 +958,150 @@ class FlowTracer(object):
return
flow_block
,
chunk_info
return
flow_block
,
chunk_info
def
_assgin_single_node_flow
(
self
,
arg_node
,
start_idx
,
end_idx
,
inputs
,
index_tracer
,
cur_node_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
):
arg_idx
=
_find_idx_by_name
(
arg_node
.
name
,
index_tracer
.
nodes_list
)
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
return
True
# find arg dim
if
cur_node_dim
is
not
None
:
# dim is computed
if
arg_idx
in
cur_node_compute
[
cur_node_dim
]:
return
False
if
arg_idx
not
in
cur_node_source
[
cur_node_dim
]:
arg_dim
=
None
else
:
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
else
:
arg_dim
=
None
# get fix dim
arg_fix_dim
=
[]
if
cur_node_dim
is
not
None
:
for
i
in
cur_node_fix_dim
:
fix_dim_source
=
cur_node_source
[
i
]
if
arg_idx
in
fix_dim_source
:
arg_fix_dim
.
append
(
fix_dim_source
[
arg_idx
][
0
])
# if already in node_info, arg dim must be same
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
return
False
all_node_info
[
arg_node
][
'fix_dim'
]
=
list
(
set
(
all_node_info
[
arg_node
][
'fix_dim'
]
+
arg_fix_dim
))
# else add it to list
else
:
all_node_info
[
arg_node
]
=
{
'chunk_dim'
:
arg_dim
,
'fix_dim'
:
arg_fix_dim
}
next_node_list
.
append
(
arg_node
)
return
True
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
cur_node_list
=
[
index_tracer
.
nodes_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
'chunk_dim'
:
end_dim
,
'fix_dim'
:
[]}}
while
len
(
cur_node_list
)
>
0
:
next_node_list
=
[]
for
cur_node
in
cur_node_list
:
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
'chunk_dim'
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
'fix_dim'
]
cur_node_idx
=
_find_idx_by_name
(
cur_node
.
name
,
index_tracer
.
nodes_list
)
if
cur_node_chunk_dim
:
cur_node_compute
=
index_tracer
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
index_tracer
.
_find_source_trace_from_node
(
cur_node
)
else
:
cur_node_compute
=
cur_node_source
=
None
# get all valid args
arg_list
=
[]
for
arg
in
cur_node
.
args
:
if
type
(
arg
)
!=
type
(
cur_node
):
continue
if
_is_non_compute_node
(
arg
):
continue
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_assgin_single_node_flow
(
arg
,
start_idx
,
end_idx
,
inputs
,
index_tracer
,
cur_node_chunk_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
)
if
flow_flag
==
False
:
return
None
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
for
arg
in
arg_list
:
if
not
(
start_idx
<=
_find_idx_by_name
(
arg
.
name
,
index_tracer
.
nodes_list
)
<
end_idx
):
continue
arg_chunk_dim
=
all_node_info
[
arg
][
'chunk_dim'
]
arg_fix_dim
=
all_node_info
[
arg
][
'fix_dim'
]
arg_shape
=
_get_node_shape
(
arg
)
# add all dim as fix dim except chunk dim
for
i
,
shape
in
enumerate
(
arg_shape
):
if
shape
!=
1
and
i
!=
cur_node_chunk_dim
:
if
i
==
arg_chunk_dim
:
return
None
if
i
not
in
arg_fix_dim
:
arg_fix_dim
.
append
(
i
)
elif
"einsum"
in
cur_node
.
name
:
pass
elif
"matmul"
in
cur_node
.
name
:
pass
else
:
raise
NotImplementedError
()
cur_node_list
=
next_node_list
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
input_dict
=
{}
for
user
in
input_node
.
users
.
keys
():
if
_is_non_compute_node
(
user
):
continue
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
'chunk_dim'
]
if
chunk_dim
is
not
None
:
input_dict
[
user_idx
]
=
chunk_dim
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
inputs_dim
.
append
(
input_dict
)
for
i
in
remove_inputs
:
if
i
in
inputs
:
inputs
.
remove
(
i
)
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
"inputs"
:
inputs
,
"inputs_non_chunk"
:
[],
"inputs_dim"
:
inputs_dim
,
"outputs"
:
outputs
,
"outputs_dim"
:
end_dim
,
"args"
:
{},
}
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
chunk_info
class
MemoryEstimator
(
object
):
class
MemoryEstimator
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
...
@@ -1055,12 +1225,13 @@ class MemoryEstimator(object):
...
@@ -1055,12 +1225,13 @@ class MemoryEstimator(object):
node_source
=
self
.
index_tracer
.
_find_source_trace_from_node
(
node
)
node_source
=
self
.
index_tracer
.
_find_source_trace_from_node
(
node
)
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
k
,
v
in
input_node_dim
.
items
():
for
k
,
v
in
input_node_dim
.
items
():
# TODO: inherit dim should be list too, int now
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
])
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
])
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
return
chunk_ratio
return
chunk_ratio
for
dim
,
source
in
enumerate
(
node_source
):
for
dim
,
source
in
enumerate
(
node_source
):
if
k
in
source
and
source
[
k
]
==
inherit_dim
:
if
k
in
source
and
inherit_dim
in
source
[
k
]
:
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
return
chunk_ratio
return
chunk_ratio
return
1.
return
1.
...
@@ -1323,9 +1494,11 @@ class ChunkRegionSearch(object):
...
@@ -1323,9 +1494,11 @@ class ChunkRegionSearch(object):
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
# must be same trace idx
if
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
:
if
start_trace_idx
!=
end_trace_idx
:
print
(
1
)
continue
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
)
# dim size cannot be 1
# dim size cannot be 1
if
(
if
(
_get_node_shape
(
end_node
)[
end_dim
]
==
1
_get_node_shape
(
end_node
)[
end_dim
]
==
1
...
@@ -1343,10 +1516,16 @@ class ChunkRegionSearch(object):
...
@@ -1343,10 +1516,16 @@ class ChunkRegionSearch(object):
):
):
continue
continue
# detect flow meet
# detect flow meet
flow_block
,
chunk_info
=
self
.
flow_tracer
.
_detect_flow
(
# flow_block, chunk_info = self.flow_tracer._detect_flow(
# start_idx, start_dim, end_idx, end_dim, self.index_tracer
# )
# if flow_block:
# continue
# flow search
chunk_info
=
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
)
)
if
flow_block
:
if
chunk_info
is
None
:
continue
continue
# check index copmute
# check index copmute
if
not
self
.
index_tracer
.
check_index_duplicate
(
chunk_info
):
if
not
self
.
index_tracer
.
check_index_duplicate
(
chunk_info
):
...
...
evoformer/evoformer.py
View file @
d361d533
...
@@ -6,6 +6,13 @@ from .ops import OutProductMean
...
@@ -6,6 +6,13 @@ from .ops import OutProductMean
from
.triangle
import
PairStack
from
.triangle
import
PairStack
def
print_memory
(
init_mem
,
text
=
None
):
now_mem
=
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
-
init_mem
max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
-
init_mem
print
(
"%s now:%.2f max:%.2f"
%
(
""
if
text
is
None
else
text
,
now_mem
,
max_mem
))
torch
.
cuda
.
reset_peak_memory_stats
()
class
EvoformerBlock
(
nn
.
Module
):
class
EvoformerBlock
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
):
def
__init__
(
self
,
d_node
,
d_pair
):
...
@@ -16,9 +23,9 @@ class EvoformerBlock(nn.Module):
...
@@ -16,9 +23,9 @@ class EvoformerBlock(nn.Module):
self
.
pair_stack
=
PairStack
(
d_pair
=
d_pair
)
self
.
pair_stack
=
PairStack
(
d_pair
=
d_pair
)
def
forward
(
self
,
node
,
pair
):
def
forward
(
self
,
node
,
pair
):
node
=
node
+
self
.
msa_stack
(
node
,
pair
)
node
=
self
.
msa_stack
(
node
,
pair
)
pair
=
pair
+
self
.
communication
(
node
)
pair
=
pair
+
self
.
communication
(
node
)
pair
=
pair
+
self
.
pair_stack
(
pair
)
pair
=
self
.
pair_stack
(
pair
)
return
node
,
pair
return
node
,
pair
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment