OpenDAS / ColossalAI / Commits / ecccc91f

Unverified commit ecccc91f, authored Jan 19, 2023 by oahzxl, committed by GitHub on Jan 19, 2023.
Parent: 304f1ba1

[autochunk] support autochunk on evoformer (#2497)

Showing 9 changed files with 200 additions and 188 deletions (+200 -188).
colossalai/autochunk/autochunk_codegen.py                +3   -3
colossalai/autochunk/estimate_memory.py                  +20  -47
colossalai/autochunk/search_chunk.py                     +22  -61
colossalai/autochunk/trace_flow.py                       +53  -10
colossalai/autochunk/trace_indice.py                     +61  -15
colossalai/autochunk/utils.py                            +16  -11
tests/test_autochunk/test_evoformer_codegen.py           +10  -11
tests/test_autochunk/test_simple_evoformer_codegen.py    +12  -27
tests/test_autochunk/test_simple_evoformer_search.py     +3   -3
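This commit extends autochunk from the simplified evoformer to the full evoformer used in test_evoformer_codegen.py (node/pair masks included): index tracing learns getitem with None/Ellipsis and negative unsqueeze dims, the misnamed unflat_list helper becomes flat_list, and the memory estimator skips shapeless nodes. For orientation, here is a minimal sketch of how the updated tests drive the pipeline; the evoformer model and its node/pair inputs are assumed to exist, and import paths follow the test files rather than anything stated on this page:

    import torch
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx import ColoTracer, symbolic_trace
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import MetaTensor

    # shape metadata: MetaInfoProp requires a symbolic_trace graph
    meta_args = {"node": node.to(torch.device("meta")), "pair": pair.to(torch.device("meta"))}
    meta_graph = symbolic_trace(model, meta_args=meta_args)
    MetaInfoProp(meta_graph).propagate(MetaTensor(node, fake_device="cuda:0"),
                                       MetaTensor(pair, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=24, print_mem=False)    # budget in MB

    # codegen: the executable graph must come from ColoTracer
    graph = ColoTracer().trace(model, meta_args=meta_args)
    graph.set_codegen(codegen)    # the generated forward now loops over chunks
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    out = gm(node, pair)    # matches model(node, pair) at a lower peak memory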
colossalai/autochunk/autochunk_codegen.py
@@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
     """
     replace node name
     """
-    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
+    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"),
+                (" ", ""), ("", "\n")]
     for p in patterns:
         source = p[0] + name_from + p[1]
         target = p[0] + name_to + p[1]
         if source in context:
             context = context.replace(source, target)
             break
     return context
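The `patterns` list wraps a node name with its possible delimiters so that only whole tokens get renamed in the generated source, never substrings of longer names. A tiny illustration of the intent; the helper below is a paraphrase, and the two trailing pattern tuples in the new list above are partly reconstructed from the page:

    # Paraphrase of delimiter-based whole-token renaming, not library code.
    def replace_name(context, name_from, name_to):
        patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
        for left, right in patterns:
            source = left + name_from + right
            if source in context:
                return context.replace(source, left + name_to + right)
        return context

    code = "out = linear(node_1, weight)"
    print(replace_name(code, "node_1", "chunk_input"))
    # out = linear(chunk_input, weight)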
@@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     """
     if node_name not in reshape_size_dict:
         return context
-    for size_name, size_value in reshape_size_dict[node_name].items():
-        context = context.replace(size_name, size_value)
+    context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
     return context
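This pairs with the `_reassgin_reshape_size` change in trace_flow.py further down: `reshape_size_dict[node_name]` is no longer a per-argument dict but a two-element list `[origin_shape_str, new_shape_str]`, so the whole shape-argument list of a `view`/`reshape` call is rewritten in one pass. A hedged sketch of the new entry format (names and values are illustrative):

    # Assumed shape of a reshape_size_dict entry after this commit (reconstructed):
    reshape_size_dict = {
        "reshape_1": [
            "size_1, size_2, 256",                            # origin shape args, as emitted
            "min(chunk_size, 64 - chunk_idx), size_2, 256",   # chunked shape args
        ]
    }
    context = "reshape_1 = node_0.reshape(size_1, size_2, 256)"
    origin, new = reshape_size_dict["reshape_1"]
    print(context.replace(origin, new))
    # reshape_1 = node_0.reshape(min(chunk_size, 64 - chunk_idx), size_2, 256)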
colossalai/autochunk/estimate_memory.py
@@ -37,10 +37,10 @@ class EstimateMemory(object):

     def _add_active_node(self, n, active_list):
         new_active = self._get_output_node(n)[1]
-        if n.op == "placeholder":
+        if n.op == "placeholder" and get_node_shape(n) is not None:
             new_active.append(n.name)
         for i in new_active:
-            if i not in active_list:
+            if i not in active_list and get_node_shape(n) is not None:
                 active_list.append(i)

     def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
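Both new guards skip nodes that carry no tensor metadata (boolean flags, `None` placeholders), which now show up when tracing the full evoformer with its mask arguments. A simplified sketch of the helper contract these guards rely on; the real implementation lives in utils.py:

    # Simplified sketch of get_node_shape's contract (assumption, not verbatim).
    def get_node_shape(node):
        # torch.fx stores shape info under meta["tensor_meta"] after MetaInfoProp;
        # non-tensor nodes (ints, bools, None placeholders) have no such entry.
        tensor_meta = node.meta.get("tensor_meta", None)
        if hasattr(tensor_meta, "shape"):
            return tensor_meta.shape
        return None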
@@ -77,15 +77,11 @@ class EstimateMemory(object):
             if i in active_list:
                 active_list.remove(i)

-    def _get_chunk_inputs_size(
-        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
-    ):
+    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [
-                find_idx_by_name(i.name, node_list) for i in chunk_input_users
-            ]
+            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)
@@ -112,9 +108,7 @@ class EstimateMemory(object):
         not_contiguous_ops = ["permute"]
         inherit_contiguous_ops = ["transpose", "view"]

-        if node.op == "call_function" and any(
-            n in node.name for n in ["matmul", "reshape"]
-        ):
+        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
             for n in node.args:
                 if n in not_contiguous_list:
                     # matmul won't change origin tensor, but create a tmp copy
@@ -125,9 +119,7 @@ class EstimateMemory(object):
                     # module will just make origin tensor to contiguous
                     if delete:
                         not_contiguous_list.remove(n)
-        elif node.op == "call_method" and any(
-            i in node.name for i in not_contiguous_ops
-        ):
+        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
             if node not in not_contiguous_list:
                 not_contiguous_list.append(node)
         return mem
@@ -142,9 +134,7 @@ class EstimateMemory(object):
         else:
             return float(chunk_size) / node_shape[chunk_dim]

-    def _get_chunk_delete_node_size(
-        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
-    ):
+    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
         # if any(j in user.name for j in ['transpose', 'permute', 'view']):
         #     return 0
         if user.op in ("placeholder", "output"):
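The ratio computed just above (`float(chunk_size) / node_shape[chunk_dim]`) scales every activation inside a chunk region during estimation: only that fraction of each tensor is alive per chunk step. A worked example with illustrative numbers:

    # Worked example of the chunk ratio (illustrative numbers only).
    chunk_size, chunk_dim_len = 16, 64
    chunk_ratio = float(chunk_size) / chunk_dim_len       # 0.25
    full_act_mb = 64 * 64 * 128 * 4 / (1024**2)           # ~2.0 MB: float32 tensor of shape (64, 64, 128)
    print(full_act_mb * chunk_ratio)                      # ~0.5 MB alive per chunk step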
@@ -196,7 +186,7 @@ class EstimateMemory(object):
         Returns:
             act_memory_peak_log (List): peak memory of every node
             act_memory_after_node_log (List): memory after excuting every node
-            active_node_list_log (List): active nodes of every node. active nodes refer to
+            active_node_list_log (List): active nodes of every node. active nodes refer to
                 nodes generated but not deleted.
         """
         act_memory = 0.0
@@ -212,7 +202,7 @@ class EstimateMemory(object):
         use_chunk = True if chunk_infos is not None else False
         chunk_within = False
         chunk_region_idx = None
-        chunk_ratio = 1  # use it to estimate chunk mem
+        chunk_ratio = 1    # use it to estimate chunk mem
         chunk_inputs_names = []

         if use_chunk:
@@ -221,23 +211,18 @@ class EstimateMemory(object):
             chunk_ends = [i[1] for i in chunk_regions]
             chunk_inputs = [i["inputs"] for i in chunk_infos]
             chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
-            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-                j.name for i in chunk_inputs_non_chunk for j in i
-            ]
+            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
+                                 ] + [j.name for i in chunk_inputs_non_chunk for j in i]
             chunk_outputs = [i["outputs"][0] for i in chunk_infos]
             chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
-            chunk_sizes = [
-                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
-            ]
+            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(
-                    chunk_outputs[chunk_region_idx]
-                ) / (1024**2)
+                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)

             # determine chunk ratio for current node
             if chunk_within:
@@ -262,22 +247,13 @@ class EstimateMemory(object):
             else:
                 # forward memory
                 # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
-                act_memory += (
-                    self._get_contiguous_memory(node, not_contiguous_list)
-                    * chunk_ratio
-                    / (1024**2)
-                )
-                act_memory += (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
+                act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))

             # record max act memory
             act_memory_peak_log.append(act_memory)

             # delete useless memory
-            act_memory -= (
-                self._get_contiguous_memory(node, not_contiguous_list, delete=True)
-                * chunk_ratio
-                / (1024**2)
-            )
+            act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
+                           (1024**2))
             # delete unused vars not in chunk_input_list
             # we can't delete input nodes until chunk ends
             if chunk_within:
@@ -288,9 +264,8 @@ class EstimateMemory(object):
                     chunk_inputs_names,
                 ) / (1024**2)
             else:
-                act_memory -= self._get_delete_node_size(
-                    node, user_to_last_uses_no_free_var, chunk_inputs_names
-                ) / (1024**2)
+                act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
+                                                         chunk_inputs_names) / (1024**2)

             # log active node, only effective without chunk
             self._add_active_node(node, active_node_list)
@@ -298,9 +273,7 @@ class EstimateMemory(object):
             # if node in chunk end nodes, restore chunk settings
             if use_chunk and idx in chunk_ends:
-                act_memory -= (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 act_memory -= self._get_chunk_inputs_size(
                     chunk_inputs[chunk_region_idx],
                     chunk_inputs_non_chunk[chunk_region_idx],
colossalai/autochunk/search_chunk.py
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):
@@ -73,13 +69,11 @@ class SearchChunk(object):
         """
         free_var_idx = []
         for idx, n in enumerate(self.trace_indice.node_list):
-            if n.op == "placeholder":
+            if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx

-    def _search_max_chunk_region(
-        self, active_node: List, peak_node: Node, chunk_regions: List
-    ) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
         """
         Search max chunk region according to peak memory node
@@ -124,15 +118,9 @@ class SearchChunk(object):
             region = i["region"]
             if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                 return None
-            elif (
-                region[0] <= chunk_region_start <= region[1]
-                and chunk_region_end > region[1]
-            ):
+            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                 chunk_region_start = region[1] + 1
-            elif (
-                region[0] <= chunk_region_end <= region[1]
-                and chunk_region_start < region[0]
-            ):
+            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                 chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
@@ -164,25 +152,16 @@ class SearchChunk(object):
         for start_node, start_trace in start_traces.items():
             for start_dim, _ in enumerate(start_trace["indice"]):
                 # dim size cannot be 1
-                if (
-                    get_node_shape(end_node)[end_dim] == 1
-                    or get_node_shape(start_node)[start_dim] == 1
-                ):
+                if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                     continue
                 # check index source align
-                if not self.trace_flow.check_index_source(
-                    start_dim, start_node, start_idx, end_dim, end_node
-                ):
+                if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                     continue
                 # check index copmute
-                if not self.trace_flow.check_index_compute(
-                    start_idx, end_dim, end_node, end_idx
-                ):
+                if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                     continue
                 # flow search
-                chunk_info = self.trace_flow.flow_search(
-                    start_idx, start_dim, end_idx, end_dim
-                )
+                chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                 if chunk_info is None:
                     continue
                 # check index copmute
@@ -191,9 +170,7 @@ class SearchChunk(object):
                 chunk_infos.append(chunk_info)
         return chunk_infos

-    def _search_possible_chunk_regions(
-        self, max_chunk_region: Tuple, peak_node: Node
-    ) -> List:
+    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
@@ -206,28 +183,23 @@ class SearchChunk(object):
         """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
-        input_trace = []  # trace of a node's input nodes
+        input_trace = []    # trace of a node's input nodes
         for _, n in enumerate(self.trace_indice.node_list):
             cur_trace = {}
             for arg in n.args:
-                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
-                    arg
-                ):
+                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                     cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
             input_trace.append(cur_trace)

         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(
-                    self.trace_indice.node_list[start_idx]
-                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
+                        self.trace_indice.node_list[end_idx]):
                     continue

                 # select free dim
-                chunk_info = self._find_chunk_info(
-                    input_trace, output_trace, start_idx, end_idx
-                )
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
@@ -256,17 +228,12 @@ class SearchChunk(object):
             best_chunk_region (Dict)
         """
         peak_node = self._find_peak_node(mem_peak)
-        max_chunk_region = self._search_max_chunk_region(
-            active_node, peak_node, chunk_infos
-        )
+        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
         if max_chunk_region == None:
             return None
-        possible_chunk_regions = self._search_possible_chunk_regions(
-            max_chunk_region, peak_node
-        )
-        best_chunk_region = self.select_chunk._select_best_chunk_region(
-            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-        )
+        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
+        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos,
+                                                                        peak_node, max_chunk_region, mem_peak)
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
@@ -291,9 +258,7 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(
-            self.trace_indice.node_list
-        )
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
         mem_peak = init_mem_peak

         while True:
@@ -306,14 +271,10 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos
-            )
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos, print_mem=True
-            )
+            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
         return chunk_infos
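The hunks above reflow the outer search loop without changing its logic: estimate memory, pick the best chunk region at the current peak, and repeat until the peak stops improving. A compressed paraphrase for orientation (the helper names `search_one_region` and `stop_search` are stand-ins, not the real API):

    # Paraphrase of SearchChunk's outer loop, assuming stand-in helper names.
    chunk_infos = []
    init_mem_peak, _, active_node = estimate_memory.estimate_chunk_inference_mem(node_list)
    mem_peak = init_mem_peak
    while True:
        chunk_info = search_one_region(mem_peak, active_node, chunk_infos)    # stand-in
        if chunk_info is None:
            break
        chunk_infos.append(chunk_info)
        mem_peak, _, active_node = estimate_memory.estimate_chunk_inference_mem(node_list, chunk_infos)
        if stop_search(init_mem_peak, mem_peak):                              # stand-in
            break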
colossalai/autochunk/trace_flow.py
 from typing import Dict, List, Tuple

 from torch.fx.node import Node

 from .trace_indice import TraceIndice
 from .utils import (
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
     flat_list,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
 ...
@@ -171,7 +176,7 @@ class TraceFlow(object):
             # get cur node info
             cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
             cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-            if cur_node_chunk_dim:
+            if cur_node_chunk_dim is not None:
                 cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                 cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
             else:
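This small change matters because a chunk dim of 0 is valid but falsy, so the old `if cur_node_chunk_dim:` silently skipped any node chunked along dimension 0. Minimal demonstration:

    # Why `is not None` is required: dim 0 is a legal chunk dimension but falsy.
    chunk_dim = 0
    if chunk_dim:                 # old check: wrongly treats dim 0 like "no chunk dim"
        print("chunked (old check)")
    if chunk_dim is not None:     # new check: fires for dim 0 as intended
        print("chunked (new check)")
    # only "chunked (new check)" is printed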
@@ -223,15 +228,32 @@ class TraceFlow(object):
             cur_node_list = next_node_list
         return all_node_info

-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+        """
+        Get chunk dim for every input node for their every entry, remove unchunked nodes
+
+        Args:
+            inputs (List[Node]): input nodes
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            inputs (List(Node)): new inputs
+            inputs_dim (List): chunk dim for inputs
+        """
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
             input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
+                # skip non compute
+                if is_non_compute_node(user):
+                    continue
+                # untraced node, mostly non compute
+                if user not in all_node_info:
+                    continue
                 user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
@@ -245,12 +267,24 @@ class TraceFlow(object):
                 remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
+        # remove unchunked inputs
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
         return inputs, inputs_dim

-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+        """
+        get all useless nodes in chunk region and prepose them
+
+        Args:
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            List[Node]: all nodes to be preposed
+        """
         # get all possible prepose nodes
         maybe_prepose_nodes = []
         for node, node_info in all_node_info.items():
@@ -276,7 +310,7 @@ class TraceFlow(object):
             for cur_prepose_node in tmp_cur_prepose_nodes:
                 if prepose_flag == False:
                     break
-                for cur_prepose_node_arg in cur_prepose_node.args:
+                for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                     if type(cur_prepose_node_arg) != type(cur_prepose_node):
                         continue
                     # out of loop
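Switching from `.args` to `.all_input_nodes` is a correctness fix for torch.fx graphs: `node.args` can hide Node objects inside nested tuples and lists, while `node.all_input_nodes` is the flattened list of every Node feeding this node, from args and kwargs alike. A quick illustration with plain torch.fx:

    # Demonstrates node.args vs node.all_input_nodes on a traced function.
    import torch
    import torch.fx

    def f(x):
        return torch.stack([x, x + 1], dim=0)    # Nodes nested inside a list arg

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.stack:
            print(node.args)                # ([x, add],): Nodes hidden inside a list
            print(node.all_input_nodes)     # [x, add]: flattened Node inputs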
@@ -360,19 +394,28 @@ class TraceFlow(object):
         return chunk_info

     def _reassgin_reshape_size(self, chunk_info):
         """
         Some shape args in reshape may have changed due to chunk
         reassgin those changed shape
         """
         chunk_region = chunk_info["region"]
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
+                reshape_args = flat_list(node.args[1:])
                 reshape_log = self.trace_indice.indice_view_list[node]
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
+                new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                     if reshape_arg_dim in reshape_log["dim_to"]:
                         continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = (
-                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
-                        )
+                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
+                    else:
+                        if isinstance(reshape_arg, int):
+                            new_shape += "%s, " % str(reshape_arg)
+                        else:
+                            new_shape += "%s, " % reshape_arg.name
+                new_shape = new_shape[:-2]
+                origin_shape = str(reshape_args)[1:-1]
+                reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
colossalai/autochunk/trace_indice.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple

 from torch.fx.node import Node

-from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
+from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape


 class TraceIndice(object):
@@ -28,7 +28,7 @@ class TraceIndice(object):
         node_list (List)
     """

-    def __init__(self, node_list: List) -> None:
+    def __init__(self, node_list: List[Node]) -> None:
         self.node_list = node_list
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
@@ -198,7 +198,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["compute"]

-    def _assign_indice_as_input(self, node, node_idx, input_node=None):
+    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
         """
         Assign node's trace as its input node.
@@ -216,7 +216,7 @@ class TraceIndice(object):
         self._inherit_all_computation(input_node, node)

-    def _assign_all_indice(self, node, node_idx):
+    def _assign_all_indice(self, node: Node, node_idx: int):
         """
         Add new indice for all node's dims.
@@ -232,7 +232,7 @@ class TraceIndice(object):
             new_trace.append(self._add_indice())
         self.indice_trace_list[node_idx]["indice"] = new_trace

-    def _assign_transpose_indice(self, node, node_idx):
+    def _assign_transpose_indice(self, node: Node, node_idx: int):
         """
         Assign indice for transpose op.
         1. swap input's dim according to transpose args
@@ -249,7 +249,7 @@ class TraceIndice(object):
         self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
         self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])

-    def _assign_permute_indice(self, node, node_idx):
+    def _assign_permute_indice(self, node: Node, node_idx: int):
         """
         Assign indice for permute op.
         1. swap input's dim according to permute args
@@ -259,14 +259,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        permute_dim = unflat_list(node.args[1:])
+        permute_dim = flat_list(node.args[1:])
         input_node = node.args[0]

         self._assign_indice_as_input(node, node_idx, input_node)
         for idx, d in enumerate(permute_dim):
             self._inherit_indice(input_node, d, node, idx)

-    def _assign_linear_indice(self, node, node_idx):
+    def _assign_linear_indice(self, node: Node, node_idx: int):
         """
         Assign indice for linear op.
         1. copy trace from input node and change last indice accroding to weight
@@ -287,7 +287,7 @@ class TraceIndice(object):
         self._mark_computation(node, node_idx, [-1])

-    def _assign_matmul_indice(self, node, node_idx):
+    def _assign_matmul_indice(self, node: Node, node_idx: int):
         """
         Assign indice for matmul op.
         1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
@@ -393,7 +393,7 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [node.kwargs["dim"]])

-    def _assign_unsqueeze_indice(self, node, node_idx):
+    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
@@ -404,9 +404,13 @@ class TraceIndice(object):
         """
         self._del_dim(node_idx, -1)
         self._assign_indice_as_input(node, node_idx)
-        self._add_dim(node_idx, node.args[1])
+        dim_idx = node.args[1]
+        # unsqueeze(-1) = unsqueeze(shape_num + 1)
+        if dim_idx < 0:
+            dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
+        self._add_dim(node_idx, dim_idx)

-    def _assign_dropout_indice(self, node, node_idx):
+    def _assign_dropout_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
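The negative-dim handling mirrors how PyTorch itself resolves `unsqueeze(-1)`: index into the output's dimension range so `-1` becomes `len(shape) - 1`. A quick check of the idiom used above:

    # The list-indexing idiom converts a negative dim into its positive form.
    shape = (1, 32, 64, 256)          # output shape of the unsqueeze node
    dim_idx = -1
    if dim_idx < 0:
        dim_idx = list(range(len(shape)))[dim_idx]
    print(dim_idx)                    # 3, i.e. len(shape) - 1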
@@ -417,7 +421,7 @@ class TraceIndice(object):
         """
         self._assign_indice_as_input(node, node_idx)

-    def _assign_ones_like_indice(self, node, node_idx):
+    def _assign_ones_like_indice(self, node: Node, node_idx: int):
         """
         Assign indice for oneslike op.
         1. assign new indice for all dim
@@ -428,7 +432,47 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

-    def _assign_view_reshape_indice(self, node, node_idx):
+    def _assign_getitem_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for getitem.
+        getitem can act like slice sometimes
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        node_args = flat_list(node.args[1:])
+        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+            return
+
+        # node args should be like [Ellipsis, slice(start, step, end), None]
+        node_shape = get_node_shape(node)
+        origin_idx_count = 0
+        new_idx_count = 0
+        new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
+        for _ in range(new_dim_num):
+            self._del_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx)
+
+        for _, node_arg in enumerate(node_args):
+            node_arg_str = str(node_arg)
+            # Ellipsis means [..., ]
+            if "Ellipsis" == node_arg_str:
+                shape_gap = len(node_shape) - len(node_args) + 1
+                origin_idx_count += shape_gap
+                new_idx_count += shape_gap
+            # slice(None, None, None) means all indexes, doesn't support other slice
+            elif "slice(None, None, None)" == node_arg_str:
+                origin_idx_count += 1
+                new_idx_count += 1
+            # None means a new dim
+            elif "None" == node_arg_str:
+                self._add_dim(node_idx, new_idx_count)
+                new_idx_count += 1
+            else:
+                raise NotImplementedError()
+
+    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
         """
         Assign indice for view and reshape op.
         1. get origin shape and target shape by meta info.
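The new handler covers the mask-broadcasting pattern in the full evoformer, where tensors are indexed like `bias[..., None, :, :]`: `Ellipsis` keeps a run of leading dims, `slice(None, None, None)` keeps one dim, and each `None` inserts a fresh dim. A concrete illustration (the tensor shape is made up):

    import torch

    pair_bias = torch.randn(1, 64, 64)    # illustrative evoformer-style bias
    out = pair_bias[..., None, :, :]      # getitem args: (Ellipsis, None, slice(None, None, None), slice(None, None, None))
    print(out.shape)                      # torch.Size([1, 1, 64, 64]): None inserted a new dim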
@@ -447,7 +491,7 @@ class TraceIndice(object):
         origin_node = node.args[0]
         origin_shape = origin_node.meta["tensor_meta"].shape
         target_shape = []
-        unflated_args = unflat_list(node.args)
+        unflated_args = flat_list(node.args)
         for i in range(1, len(unflated_args)):
             if isinstance(unflated_args[i], int):
                 target_shape.append(unflated_args[i])
@@ -544,6 +588,8 @@ class TraceIndice(object):
             self._assign_einsum_indice(node, idx)
         elif "layer_norm" in node.name:
             self._assign_layernorm_indice(node, idx)
+        elif "getitem" in node.name:
+            self._assign_getitem_indice(node, idx)
         elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
             continue
         else:
colossalai/autochunk/utils.py
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
 from torch.fx.node import Node


-def unflat_list(inputs):
+def flat_list(inputs):
     """
-    unflat a list by recursion
+    flat a list by recursion
     """
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
-            res.extend(unflat_list(i))
+            res.extend(flat_list(i))
         else:
             res.append(i)
     return res
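The rename fixes a misleading name: the helper flattens, it does not un-flatten. Behavior check, restating the function above:

    # flat_list recursively flattens nested lists/tuples/sets into one flat list.
    def flat_list(inputs):
        res = []
        for i in inputs:
            if isinstance(i, (list, set, tuple)):
                res.extend(flat_list(i))
            else:
                res.append(i)
        return res

    print(flat_list([1, (2, 3), [4, [5]]]))    # [1, 2, 3, 4, 5]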
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):


 def is_non_compute_node(node):
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
+    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
         return True
+    if "getitem" in node.name:
+        node_args = flat_list(node.args[1:])
+        for node_arg in node_args:
+            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
+                return False
+        return True
     return False
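`getitem` is no longer unconditionally "non-compute": when its index arguments contain `None` or `Ellipsis` it reshapes its input (see `_assign_getitem_indice` in trace_indice.py above), so it must be traced. A sketch of the classification, using illustrative stand-ins rather than real torch.fx nodes:

    # Illustrative stand-ins for FX nodes (not real torch.fx objects).
    class FakeNode:
        def __init__(self, name, args):
            self.op, self.name, self.args = "call_function", name, args

    src = object()                                            # pretend upstream node
    plain = FakeNode("getitem_1", (src, 0))                   # plain tuple indexing
    slicing = FakeNode("getitem_2", (src, (Ellipsis, None)))  # inserts a dim

    print(is_non_compute_node(plain))      # True: no None/Ellipsis in its index args
    print(is_non_compute_node(slicing))    # False: it changes the tensor's shape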
@@ -40,15 +45,15 @@ def get_node_shape(node):


 def is_non_compute_node_except_placeholder(node):
-    if any(i in node.op for i in ["get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
-        return True
-    return False
+    if "placeholder" in node.op:
+        return False
+    return is_non_compute_node(node)


 def is_non_compute_node_except_placeholder_output(node):
-    if any(i in node.op for i in ["get_attr"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
-        return True
-    return False
+    if "output" in node.op:
+        return False
+    return is_non_compute_node_except_placeholder(node)


 def find_idx_by_name(name, nodes_list):
tests/test_autochunk/test_evoformer_codegen.py
@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
     # for memory test
     # model = model.cuda()
     # torch.cuda.reset_peak_memory_stats()
     # now_mem = torch.cuda.memory_allocated() / 1024**2
     # with torch.no_grad():
     #     node1 = node.clone()
     #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1)
     # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

     # test forward
     model = model.cuda()
@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         MetaTensor(node_mask, fake_device="cuda:0"),
         MetaTensor(pair_mask, fake_device="cuda:0"),
     )
-    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

     # trace and recompile
     # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "_mask_trans": True,
         },
     )
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
+    # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair, node_mask, pair_mask)
     gpc.destroy()
@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
 def test_evoformer_codegen(msa_len, pair_len, max_memory):
@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):

 if __name__ == "__main__":
-    _test_evoformer_codegen(0, 32, 64, 25)
+    _test_evoformer_codegen(0, 32, 64, 24)
tests/test_autochunk/test_simple_evoformer_codegen.py
@@ -13,7 +13,7 @@ except:

 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer
+from colossalai.fx import ColoTracer, symbolic_trace
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # for memory test
-    # torch.cuda.reset_peak_memory_stats()
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # with torch.no_grad():
-    #     node1 = node.clone()
-    #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
-
     # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
         fx_out = gm(node, pair)
@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()

+    # meta info prop
+    meta_graph = symbolic_trace(model,
+                                meta_args={
+                                    "node": node.to(torch.device("meta")),
+                                    "pair": pair.to(torch.device("meta")),
+                                })    # must use symbolic_trace
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+
     # trace the module and replace codegen
     graph = ColoTracer().trace(
         model,
@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
-    interp = MetaInfoProp(gm_prop)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    # now run it twice to get meta info in graph module, not necessary
-    gm = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
+    # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair)
     gpc.destroy()
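As the in-code comments note, MetaInfoProp needs a graph from `symbolic_trace` (so shape metadata lands on the traced nodes), while the codegen-carrying graph must come from `ColoTracer`. The refactor above builds `AutoChunkCodeGen` once from the meta graph instead of tracing the model a third time; the resulting pattern is the one sketched near the top of this page.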
tests/test_autochunk/test_simple_evoformer_search.py
@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
         str(target_regions),
     )
     for region in target_regions:
-        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )
     for region in found_regions:
         assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
        )
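The `%d` to `%s` switch matters because `max_memory` can be `None` (the unlimited-budget case in the parametrize lists above), which `%d` cannot format, while `str(max_memory)` is safe for both ints and `None`:

    max_memory = None
    print("maxmem:%s" % str(max_memory))    # maxmem:None
    # "maxmem:%d" % max_memory would raise TypeError for None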