OpenDAS / ColossalAI / Commits

Commit ecccc91f (unverified)
Authored Jan 19, 2023 by oahzxl; committed by GitHub, Jan 19, 2023
Parent: 304f1ba1

[autochunk] support autochunk on evoformer (#2497)
Showing 9 changed files with 200 additions and 188 deletions (+200 −188):
  colossalai/autochunk/autochunk_codegen.py               +3   -3
  colossalai/autochunk/estimate_memory.py                 +20  -47
  colossalai/autochunk/search_chunk.py                    +22  -61
  colossalai/autochunk/trace_flow.py                      +53  -10
  colossalai/autochunk/trace_indice.py                    +61  -15
  colossalai/autochunk/utils.py                           +16  -11
  tests/test_autochunk/test_evoformer_codegen.py          +10  -11
  tests/test_autochunk/test_simple_evoformer_codegen.py   +12  -27
  tests/test_autochunk/test_simple_evoformer_search.py    +3   -3
colossalai/autochunk/autochunk_codegen.py
@@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
     """
     replace node name
     """
-    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
+    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
     for p in patterns:
         source = p[0] + name_from + p[1]
         target = p[0] + name_to + p[1]
         if source in context:
             context = context.replace(source, target)
+            break
     return context
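Note: the two new delimiter pairs (" ", "") and ("", " ") let a name match at the end or the start of a line, and the added break stops after the first pattern that hits, so a name is rewritten under exactly one delimiter context per call. A minimal standalone sketch of the idea (the function name and sample strings here are illustrative, not from the commit):

    # Boundary-delimited name replacement, as in the hunk above. A bare
    # str.replace("add_1", ...) would also rewrite "add_1" inside unrelated
    # identifiers such as "add_10"; the delimiter pairs anchor it as a token.
    def replace_name(context: str, name_from: str, name_to: str) -> str:
        patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
        for prefix, suffix in patterns:
            source = prefix + name_from + suffix
            if source in context:
                # first matching delimiter pair wins, mirroring the new `break`
                return context.replace(source, prefix + name_to + suffix)
        return context

    print(replace_name("out = add_1 + x", "add_1", "chunk_out"))
    # -> "out = chunk_out + x"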
@@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     """
     if node_name not in reshape_size_dict:
         return context
-    for size_name, size_value in reshape_size_dict[node_name].items():
-        context = context.replace(size_name, size_value)
+    context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
     return context
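Note: reshape_size_dict[node_name] is now a two-element list [origin_shape_str, new_shape_str] (built in trace_flow._reassgin_reshape_size further down) instead of a per-argument dict, so the whole shape-argument list of a generated reshape call is rewritten in a single replace. An illustrative run with hypothetical values:

    # Hypothetical values; the dict is produced by _reassgin_reshape_size.
    context = "reshape_1 = x.reshape(96, 64, 128)"
    reshape_size_dict = {"reshape_1": ["96, 64, 128", "min(chunk_size, 96 - chunk_idx), 64, 128"]}

    node_name = "reshape_1"
    if node_name in reshape_size_dict:
        context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
    print(context)
    # reshape_1 = x.reshape(min(chunk_size, 96 - chunk_idx), 64, 128)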
colossalai/autochunk/estimate_memory.py
@@ -37,10 +37,10 @@ class EstimateMemory(object):
     def _add_active_node(self, n, active_list):
         new_active = self._get_output_node(n)[1]
-        if n.op == "placeholder":
+        if n.op == "placeholder" and get_node_shape(n) is not None:
             new_active.append(n.name)
         for i in new_active:
-            if i not in active_list:
+            if i not in active_list and get_node_shape(n) is not None:
                 active_list.append(i)

     def _get_delete_node(self, user, user_to_last_uses, to_keep=None):

@@ -77,15 +77,11 @@ class EstimateMemory(object):
             if i in active_list:
                 active_list.remove(i)

-    def _get_chunk_inputs_size(
-        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
-    ):
+    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [
-                find_idx_by_name(i.name, node_list) for i in chunk_input_users
-            ]
+            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)

@@ -112,9 +108,7 @@ class EstimateMemory(object):
         not_contiguous_ops = ["permute"]
         inherit_contiguous_ops = ["transpose", "view"]
-        if node.op == "call_function" and any(
-            n in node.name for n in ["matmul", "reshape"]
-        ):
+        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
             for n in node.args:
                 if n in not_contiguous_list:
                     # matmul won't change origin tensor, but create a tmp copy

@@ -125,9 +119,7 @@ class EstimateMemory(object):
                     # module will just make origin tensor to contiguous
                     if delete:
                         not_contiguous_list.remove(n)
-        elif node.op == "call_method" and any(
-            i in node.name for i in not_contiguous_ops
-        ):
+        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
             if node not in not_contiguous_list:
                 not_contiguous_list.append(node)
         return mem

@@ -142,9 +134,7 @@ class EstimateMemory(object):
         else:
             return float(chunk_size) / node_shape[chunk_dim]

-    def _get_chunk_delete_node_size(
-        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
-    ):
+    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
         # if any(j in user.name for j in ['transpose', 'permute', 'view']):
         #     return 0
         if user.op in ("placeholder", "output"):

@@ -196,7 +186,7 @@ class EstimateMemory(object):
         Returns:
             act_memory_peak_log (List): peak memory of every node
             act_memory_after_node_log (List): memory after excuting every node
             active_node_list_log (List): active nodes of every node. active nodes refer to
                 nodes generated but not deleted.
         """
         act_memory = 0.0

@@ -212,7 +202,7 @@ class EstimateMemory(object):
         use_chunk = True if chunk_infos is not None else False
         chunk_within = False
         chunk_region_idx = None
         chunk_ratio = 1    # use it to estimate chunk mem
         chunk_inputs_names = []

         if use_chunk:

(the two hunks above change whitespace only; no tokens differ)
@@ -221,23 +211,18 @@ class EstimateMemory(object):
             chunk_ends = [i[1] for i in chunk_regions]
             chunk_inputs = [i["inputs"] for i in chunk_infos]
             chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
-            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-                j.name for i in chunk_inputs_non_chunk for j in i
-            ]
+            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
+                                 ] + [j.name for i in chunk_inputs_non_chunk for j in i]
             chunk_outputs = [i["outputs"][0] for i in chunk_infos]
             chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
-            chunk_sizes = [
-                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
-            ]
+            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(
-                    chunk_outputs[chunk_region_idx]
-                ) / (1024**2)
+                act_memory += self._get_output_node_size(
+                    chunk_outputs[chunk_region_idx]) / (1024**2)

             # determine chunk ratio for current node
             if chunk_within:

@@ -262,22 +247,13 @@ class EstimateMemory(object):
             else:
                 # forward memory
                 # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
-                act_memory += (
-                    self._get_contiguous_memory(node, not_contiguous_list)
-                    * chunk_ratio / (1024**2)
-                )
-                act_memory += (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
+                act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))

             # record max act memory
             act_memory_peak_log.append(act_memory)

             # delete useless memory
-            act_memory -= (
-                self._get_contiguous_memory(node, not_contiguous_list, delete=True)
-                * chunk_ratio / (1024**2)
-            )
+            act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
+                           (1024**2))

             # delete unused vars not in chunk_input_list
             # we can't delete input nodes until chunk ends
             if chunk_within:

@@ -288,9 +264,8 @@ class EstimateMemory(object):
                         chunk_inputs_names,
                     ) / (1024**2)
                 else:
                     act_memory -= self._get_delete_node_size(
                         node,
                         user_to_last_uses_no_free_var,
-                        chunk_inputs_names
-                    ) / (1024**2)
+                        chunk_inputs_names) / (1024**2)

             # log active node, only effective without chunk
             self._add_active_node(node, active_node_list)

@@ -298,9 +273,7 @@ class EstimateMemory(object):
             # if node in chunk end nodes, restore chunk settings
             if use_chunk and idx in chunk_ends:
-                act_memory -= (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 act_memory -= self._get_chunk_inputs_size(
                     chunk_inputs[chunk_region_idx],
                     chunk_inputs_non_chunk[chunk_region_idx],
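Note: everything allocated inside a chunk region is scaled by chunk_ratio, which _get_chunk_ratio computes as float(chunk_size) / node_shape[chunk_dim] (visible in the @@ -142 hunk above). A toy illustration of why that approximates chunked peak memory, with made-up sizes:

    # Toy illustration of the chunk_ratio bookkeeping; all numbers hypothetical.
    # A (96, 64, 128) fp32 activation is 3 MB; chunking dim 0 in slices of 8
    # keeps only 8/96 of it alive at any moment inside the chunk region.
    node_numel = 96 * 64 * 128
    full_mb = node_numel * 4 / (1024**2)    # 4 bytes per fp32 element
    chunk_ratio = float(8) / 96             # chunk_size / node_shape[chunk_dim]

    act_memory = 0.0
    act_memory += full_mb * chunk_ratio     # as in `act_memory += ... * chunk_ratio`
    print(full_mb, act_memory)              # 3.0 0.25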
colossalai/autochunk/search_chunk.py
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):

@@ -73,13 +69,11 @@ class SearchChunk(object):
         """
         free_var_idx = []
         for idx, n in enumerate(self.trace_indice.node_list):
-            if n.op == "placeholder":
+            if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx

-    def _search_max_chunk_region(
-        self, active_node: List, peak_node: Node, chunk_regions: List
-    ) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
         """
         Search max chunk region according to peak memory node

@@ -124,15 +118,9 @@ class SearchChunk(object):
             region = i["region"]
             if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                 return None
-            elif (
-                region[0] <= chunk_region_start <= region[1]
-                and chunk_region_end > region[1]
-            ):
+            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                 chunk_region_start = region[1] + 1
-            elif (
-                region[0] <= chunk_region_end <= region[1]
-                and chunk_region_start < region[0]
-            ):
+            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                 chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
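Note: the two elif branches shrink the candidate region so it cannot straddle an already-chunked region: a start that falls inside an existing region is pushed past that region's end, and an end that falls inside one is pulled before that region's start. A small walk-through with hypothetical indices:

    # Hypothetical walk-through of the overlap trimming above.
    chunk_region_start, chunk_region_end = 15, 30    # candidate region
    region = (10, 20)                                # an already-selected chunk region

    # candidate start falls inside the existing region and the candidate end
    # sticks out to the right, so the start is pushed past the region:
    if region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
        chunk_region_start = region[1] + 1

    print(chunk_region_start, chunk_region_end)    # 21 30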
@@ -164,25 +152,16 @@ class SearchChunk(object):
         for start_node, start_trace in start_traces.items():
             for start_dim, _ in enumerate(start_trace["indice"]):
                 # dim size cannot be 1
-                if (
-                    get_node_shape(end_node)[end_dim] == 1
-                    or get_node_shape(start_node)[start_dim] == 1
-                ):
+                if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                     continue
                 # check index source align
-                if not self.trace_flow.check_index_source(
-                    start_dim, start_node, start_idx, end_dim, end_node
-                ):
+                if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                     continue
                 # check index copmute
-                if not self.trace_flow.check_index_compute(
-                    start_idx, end_dim, end_node, end_idx
-                ):
+                if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                     continue
                 # flow search
-                chunk_info = self.trace_flow.flow_search(
-                    start_idx, start_dim, end_idx, end_dim
-                )
+                chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                 if chunk_info is None:
                     continue
                 # check index copmute

@@ -191,9 +170,7 @@ class SearchChunk(object):
                 chunk_infos.append(chunk_info)
         return chunk_infos

-    def _search_possible_chunk_regions(
-        self, max_chunk_region: Tuple, peak_node: Node
-    ) -> List:
+    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.

@@ -206,28 +183,23 @@ class SearchChunk(object):
         """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
         input_trace = []    # trace of a node's input nodes
         for _, n in enumerate(self.trace_indice.node_list):
             cur_trace = {}
             for arg in n.args:
-                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
-                    arg
-                ):
+                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                     cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
             input_trace.append(cur_trace)

         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(
-                    self.trace_indice.node_list[start_idx]
-                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
+                        self.trace_indice.node_list[end_idx]):
                     continue
                 # select free dim
-                chunk_info = self._find_chunk_info(
-                    input_trace, output_trace, start_idx, end_idx
-                )
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region

@@ -256,17 +228,12 @@ class SearchChunk(object):
             best_chunk_region (Dict)
         """
         peak_node = self._find_peak_node(mem_peak)
-        max_chunk_region = self._search_max_chunk_region(
-            active_node, peak_node, chunk_infos
-        )
+        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
         if max_chunk_region == None:
             return None
-        possible_chunk_regions = self._search_possible_chunk_regions(
-            max_chunk_region, peak_node
-        )
-        best_chunk_region = self.select_chunk._select_best_chunk_region(
-            possible_chunk_regions, chunk_infos, peak_node,
-            max_chunk_region, mem_peak
-        )
+        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
+        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos,
+                                                                        peak_node, max_chunk_region, mem_peak)
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region

@@ -291,9 +258,7 @@ class SearchChunk(object):
         (
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(
-            self.trace_indice.node_list
-        )
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
         mem_peak = init_mem_peak

         while True:

@@ -306,14 +271,10 @@ class SearchChunk(object):
             (
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos
-            )
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos, print_mem=True
-            )
+            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
         return chunk_infos
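Note: taken together, the last two hunks show the shape of the outer search: estimate the un-chunked peak once, then repeatedly pick a region, re-estimate with the accumulated chunk_infos, and stop when _stop_search decides the peak no longer improves. A hedged paraphrase (the helper name `_search_one_region` is a stand-in for the peak-node / max-region / select-best steps above, not a name from the commit):

    # Hedged paraphrase of the search loop; not the literal method from the commit.
    def search_region(self):
        chunk_infos = []
        init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
        mem_peak = init_mem_peak
        while True:
            chunk_info = self._search_one_region(active_node, mem_peak, chunk_infos)    # stand-in name
            if chunk_info is None:
                break
            chunk_infos.append(chunk_info)
            mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
                self.trace_indice.node_list, chunk_infos)
            if self._stop_search(init_mem_peak, mem_peak):
                break
        return chunk_infos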
colossalai/autochunk/trace_flow.py
(top of file)
+from typing import Dict, List, Tuple
+
 from torch.fx.node import Node

 from .trace_indice import TraceIndice
 from .utils import (
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
+    flat_list,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,

@@ -171,7 +176,7 @@ class TraceFlow(object):
             # get cur node info
             cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
             cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-            if cur_node_chunk_dim:
+            if cur_node_chunk_dim is not None:
                 cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                 cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
             else:
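Note: the `is not None` fix matters because a chunk dim of 0 is a perfectly valid dimension but is falsy in Python, so the old `if cur_node_chunk_dim:` silently skipped nodes chunked along dim 0. Concretely:

    # Why `if chunk_dim:` is wrong when dim 0 is legal:
    chunk_dim = 0
    if chunk_dim:
        print("taken")    # never runs: 0 is falsy
    if chunk_dim is not None:
        print("taken")    # runs: 0 is a real dim; only None means "no chunk dim"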
@@ -223,15 +228,32 @@ class TraceFlow(object):
                 cur_node_list = next_node_list
         return all_node_info

-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+        """
+        Get chunk dim for every input node for their every entry, remove unchunked nodes
+
+        Args:
+            inputs (List[Node]): input nodes
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            inputs (List(Node)): new inputs
+            inputs_dim (List): chunk dim for inputs
+        """
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
             input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
+                # skip non compute
                 if is_non_compute_node(user):
                     continue
+                # untraced node, mostly non compute
+                if user not in all_node_info:
+                    continue
                 user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
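Note: per the new docstring, the method returns the surviving inputs plus inputs_dim, one entry per input describing the chunk dim its in-region users need; inputs whose users are all unchunked are collected in remove_inputs and dropped. The exact value layout is not visible in this hunk, so the following is only an inferred illustration:

    # Inferred illustration (not from the commit): one dict per surviving
    # input, keyed by the user's index in trace_indice.node_list.
    inputs = ["node_a", "node_b"]    # stand-ins for fx Nodes
    inputs_dim = [
        {12: 0, 14: 0},    # node_a: users at indices 12 and 14 consume it along dim 0
        {13: 1},           # node_b: its user at index 13 consumes it along dim 1
    ]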
@@ -245,12 +267,24 @@ class TraceFlow(object):
                     remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
+        # remove unchunked inputs
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
         return inputs, inputs_dim

-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+        """
+        get all useless nodes in chunk region and prepose them
+
+        Args:
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            List[Node]: all nodes to be preposed
+        """
         # get all possible prepose nodes
         maybe_prepose_nodes = []
         for node, node_info in all_node_info.items():

@@ -276,7 +310,7 @@ class TraceFlow(object):
             for cur_prepose_node in tmp_cur_prepose_nodes:
                 if prepose_flag == False:
                     break
-                for cur_prepose_node_arg in cur_prepose_node.args:
+                for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                     if type(cur_prepose_node_arg) != type(cur_prepose_node):
                         continue
                     # out of loop
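Note: switching from .args to .all_input_nodes is more than cosmetic: torch.fx's Node.all_input_nodes yields only the Node-valued inputs (collected from args and kwargs), so constants such as ints, slices, and shape tuples never reach the type check in this loop. A quick way to see the difference:

    import torch
    import torch.fx

    def f(x):
        return x.reshape(4, -1)

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target == "reshape":
            print(node.args)               # (x, 4, -1): the Node mixed with ints
            print(node.all_input_nodes)    # [x]: only the Node inputs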
@@ -360,19 +394,28 @@ class TraceFlow(object):
         return chunk_info

     def _reassgin_reshape_size(self, chunk_info):
+        """
+        Some shape args in reshape may have changed due to chunk
+        reassgin those changed shape
+        """
         chunk_region = chunk_info["region"]
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
+                reshape_args = flat_list(node.args[1:])
+                reshape_log = self.trace_indice.indice_view_list[node]
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
+                new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+                    if reshape_arg_dim in reshape_log["dim_to"]:
+                        continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = (
-                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
-                        )
+                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
+                    else:
+                        if isinstance(reshape_arg, int):
+                            new_shape += "%s, " % str(reshape_arg)
+                        else:
+                            new_shape += "%s, " % reshape_arg.name
+                new_shape = new_shape[:-2]
+                origin_shape = str(reshape_args)[1:-1]
+                reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
colossalai/autochunk/trace_indice.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple

 from torch.fx.node import Node

-from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
+from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape


 class TraceIndice(object):

@@ -28,7 +28,7 @@ class TraceIndice(object):
         node_list (List)
     """

-    def __init__(self, node_list: List) -> None:
+    def __init__(self, node_list: List[Node]) -> None:
         self.node_list = node_list
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}

@@ -198,7 +198,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["compute"]

-    def _assign_indice_as_input(self, node, node_idx, input_node=None):
+    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
         """
         Assign node's trace as its input node.

@@ -216,7 +216,7 @@ class TraceIndice(object):
             self._inherit_all_computation(input_node, node)

-    def _assign_all_indice(self, node, node_idx):
+    def _assign_all_indice(self, node: Node, node_idx: int):
         """
         Add new indice for all node's dims.

@@ -232,7 +232,7 @@ class TraceIndice(object):
             new_trace.append(self._add_indice())
         self.indice_trace_list[node_idx]["indice"] = new_trace

-    def _assign_transpose_indice(self, node, node_idx):
+    def _assign_transpose_indice(self, node: Node, node_idx: int):
         """
         Assign indice for transpose op.
         1. swap input's dim according to transpose args

@@ -249,7 +249,7 @@ class TraceIndice(object):
         self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
         self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])

-    def _assign_permute_indice(self, node, node_idx):
+    def _assign_permute_indice(self, node: Node, node_idx: int):
         """
         Assign indice for permute op.
         1. swap input's dim according to permute args

@@ -259,14 +259,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        permute_dim = unflat_list(node.args[1:])
+        permute_dim = flat_list(node.args[1:])
         input_node = node.args[0]

         self._assign_indice_as_input(node, node_idx, input_node)
         for idx, d in enumerate(permute_dim):
             self._inherit_indice(input_node, d, node, idx)

-    def _assign_linear_indice(self, node, node_idx):
+    def _assign_linear_indice(self, node: Node, node_idx: int):
         """
         Assign indice for linear op.
         1. copy trace from input node and change last indice accroding to weight

@@ -287,7 +287,7 @@ class TraceIndice(object):
         self._mark_computation(node, node_idx, [-1])

-    def _assign_matmul_indice(self, node, node_idx):
+    def _assign_matmul_indice(self, node: Node, node_idx: int):
         """
         Assign indice for matmul op.
         1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)

@@ -393,7 +393,7 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [node.kwargs["dim"]])

-    def _assign_unsqueeze_indice(self, node, node_idx):
+    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim

@@ -404,9 +404,13 @@ class TraceIndice(object):
         """
         self._del_dim(node_idx, -1)
         self._assign_indice_as_input(node, node_idx)
-        self._add_dim(node_idx, node.args[1])
+        dim_idx = node.args[1]
+        # unsqueeze(-1) = unsqueeze(shape_num + 1)
+        if dim_idx < 0:
+            dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
+        self._add_dim(node_idx, dim_idx)

-    def _assign_dropout_indice(self, node, node_idx):
+    def _assign_dropout_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
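Note: the new branch folds a negative unsqueeze dim into its positive position before recording it, since the indice trace is indexed by positive dim positions; list(range(ndim))[dim_idx] is just a compact way to wrap a negative index. For example:

    # Same normalization as the diff, for a node whose output has 4 dims:
    ndim = 4
    dim_idx = -1
    if dim_idx < 0:
        dim_idx = list(range(ndim))[dim_idx]
    print(dim_idx)    # 3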
@@ -417,7 +421,7 @@ class TraceIndice(object):
         """
         self._assign_indice_as_input(node, node_idx)

-    def _assign_ones_like_indice(self, node, node_idx):
+    def _assign_ones_like_indice(self, node: Node, node_idx: int):
         """
         Assign indice for oneslike op.
         1. assign new indice for all dim

@@ -428,7 +432,47 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

+    def _assign_getitem_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for getitem.
+        getitem can act like slice sometimes
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        node_args = flat_list(node.args[1:])
+        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+            return
+
+        # node args should be like [Ellipsis, slice(start, step, end), None]
+        node_shape = get_node_shape(node)
+        origin_idx_count = 0
+        new_idx_count = 0
+        new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
+        for _ in range(new_dim_num):
+            self._del_dim(node_idx, 0)
+
+        self._assign_indice_as_input(node, node_idx)
+
+        for _, node_arg in enumerate(node_args):
+            node_arg_str = str(node_arg)
+            # Ellipsis means [..., ]
+            if "Ellipsis" == node_arg_str:
+                shape_gap = len(node_shape) - len(node_args) + 1
+                origin_idx_count += shape_gap
+                new_idx_count += shape_gap
+            # slice(None, None, None) means all indexes, doesn't support other slice
+            elif "slice(None, None, None)" == node_arg_str:
+                origin_idx_count += 1
+                new_idx_count += 1
+            # None means a new dim
+            elif "None" == node_arg_str:
+                self._add_dim(node_idx, new_idx_count)
+                new_idx_count += 1
+            else:
+                raise NotImplementedError()
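Note: in an FX graph an expression like pair_mask[..., None] becomes a getitem call whose non-tensor args are Ellipsis, slice objects, and None; the new handler classifies them by their str() form exactly as the comments say. To see the raw args the handler receives (assumes torch >= 1.12 with fx available):

    import torch
    import torch.fx

    def f(x):
        return x[..., None]

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.target, node.args)
    # prints the builtin getitem target with args (x, (Ellipsis, None));
    # str() of those args gives "Ellipsis" and "None", so the handler keeps
    # the existing dims and adds one new dim at the end.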
-    def _assign_view_reshape_indice(self, node, node_idx):
+    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
         """
         Assign indice for view and reshape op.
         1. get origin shape and target shape by meta info.

@@ -447,7 +491,7 @@ class TraceIndice(object):
         origin_node = node.args[0]
         origin_shape = origin_node.meta["tensor_meta"].shape
         target_shape = []
-        unflated_args = unflat_list(node.args)
+        unflated_args = flat_list(node.args)
         for i in range(1, len(unflated_args)):
             if isinstance(unflated_args[i], int):
                 target_shape.append(unflated_args[i])

@@ -544,6 +588,8 @@ class TraceIndice(object):
                 self._assign_einsum_indice(node, idx)
             elif "layer_norm" in node.name:
                 self._assign_layernorm_indice(node, idx)
+            elif "getitem" in node.name:
+                self._assign_getitem_indice(node, idx)
             elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
                 continue
             else:
colossalai/autochunk/utils.py
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple

 from torch.fx.node import Node


-def unflat_list(inputs):
+def flat_list(inputs):
     """
-    unflat a list by recursion
+    flat a list by recursion
     """
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
-            res.extend(unflat_list(i))
+            res.extend(flat_list(i))
         else:
             res.append(i)
     return res
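Note: the rename matches the behavior: the helper flattens arbitrarily nested lists/sets/tuples, which is why the permute/reshape handlers above can treat node.args[1:] uniformly whether dims were passed separately or as a single tuple. For example:

    from colossalai.autochunk.utils import flat_list

    print(flat_list([0, 2, 1]))      # [0, 2, 1]   as in x.permute(0, 2, 1)
    print(flat_list([(0, 2, 1)]))    # [0, 2, 1]   as in x.permute((0, 2, 1))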
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):

 def is_non_compute_node(node):
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
+    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name
+                                                                               for i in ["getattr"]):
         return True
+    if "getitem" in node.name:
+        node_args = flat_list(node.args[1:])
+        for node_arg in node_args:
+            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
+                return False
+        return True
     return False

@@ -40,15 +45,15 @@ def get_node_shape(node):

 def is_non_compute_node_except_placeholder(node):
-    if any(i in node.op for i in ["get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
-        return True
-    return False
+    if "placeholder" in node.op:
+        return False
+    return is_non_compute_node(node)


 def is_non_compute_node_except_placeholder_output(node):
-    if any(i in node.op for i in ["get_attr"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
-        return True
-    return False
+    if "output" in node.op:
+        return False
+    return is_non_compute_node_except_placeholder(node)


 def find_idx_by_name(name, nodes_list):
tests/test_autochunk/test_evoformer_codegen.py
@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
     # for memory test
+    # model = model.cuda()
     # torch.cuda.reset_peak_memory_stats()
     # now_mem = torch.cuda.memory_allocated() / 1024**2
     # with torch.no_grad():
     #     node1 = node.clone()
     #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1)
     # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

     # test forward
     model = model.cuda()

@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         MetaTensor(node_mask, fake_device="cuda:0"),
         MetaTensor(pair_mask, fake_device="cuda:0"),
     )
-    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

     # trace and recompile
     # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer

@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "_mask_trans": True,
         },
     )
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair, node_mask, pair_mask)
     gpc.destroy()

@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
 def test_evoformer_codegen(msa_len, pair_len, max_memory):

@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):

 if __name__ == "__main__":
-    _test_evoformer_codegen(0, 32, 64, 25)
+    _test_evoformer_codegen(0, 32, 64, 24)
tests/test_autochunk/test_simple_evoformer_codegen.py
@@ -13,7 +13,7 @@ except:

 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer
+from colossalai.fx import ColoTracer, symbolic_trace
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule

@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # for memory test
-    # torch.cuda.reset_peak_memory_stats()
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # with torch.no_grad():
-    #     node1 = node.clone()
-    #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
-    # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
         fx_out = gm(node, pair)

@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()

+    # meta info prop
+    meta_graph = symbolic_trace(model,
+                                meta_args={
+                                    "node": node.to(torch.device("meta")),
+                                    "pair": pair.to(torch.device("meta")),
+                                })    # must use symbolic_trace
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+
     # trace the module and replace codegen
     graph = ColoTracer().trace(
         model,

@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
-    interp = MetaInfoProp(gm_prop)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    # now run it twice to get meta info in graph module, not necessary
-    gm = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair)
     gpc.destroy()
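Note: the reworked test also documents the intended AutoChunk pipeline: trace once with symbolic_trace and meta_args, run MetaInfoProp so nodes carry tensor metadata, build AutoChunkCodeGen from that meta graph, then attach the codegen to a ColoTracer graph and recompile. A condensed sketch (assumes model/node/pair as in the test; the MetaInfoProp/MetaTensor import paths are the usual colossalai.fx ones and may need adjusting to your version):

    import torch
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx import ColoTracer, symbolic_trace
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import MetaTensor

    meta_args = {"node": node.to(torch.device("meta")), "pair": pair.to(torch.device("meta"))}
    meta_graph = symbolic_trace(model, meta_args=meta_args)    # must use symbolic_trace
    MetaInfoProp(meta_graph).propagate(MetaTensor(node, fake_device="cuda:0"),
                                       MetaTensor(pair, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

    graph = ColoTracer().trace(model, meta_args=meta_args)     # CodeGen requires ColoTracer
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()    # gm's forward now evaluates the chunked code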
tests/test_autochunk/test_simple_evoformer_search.py
@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
         str(target_regions),
     )
     for region in target_regions:
-        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )
     for region in found_regions:
         assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )