Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
63199c66
Unverified
Commit
63199c66
authored
Jan 31, 2023
by
oahzxl
Committed by
GitHub
Jan 31, 2023
Browse files
[autochunk] support transformer (#2526)
parent
6e0faa70
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1198 additions
and
964 deletions
+1198
-964
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+5
-2
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+14
-52
colossalai/autochunk/select_chunk.py
colossalai/autochunk/select_chunk.py
+48
-76
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+68
-44
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+241
-100
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+41
-9
tests/test_autochunk/benchmark_simple_evoformer.py
tests/test_autochunk/benchmark_simple_evoformer.py
+0
-94
tests/test_autochunk/test_alphafold/test_alphafold_utils.py
tests/test_autochunk/test_alphafold/test_alphafold_utils.py
+122
-0
tests/test_autochunk/test_alphafold/test_evoformer_block.py
tests/test_autochunk/test_alphafold/test_evoformer_block.py
+95
-0
tests/test_autochunk/test_alphafold/test_evoformer_stack.py
tests/test_autochunk/test_alphafold/test_evoformer_stack.py
+90
-0
tests/test_autochunk/test_alphafold/test_extramsa_block.py
tests/test_autochunk/test_alphafold/test_extramsa_block.py
+96
-0
tests/test_autochunk/test_diffuser/test_diffuser_utils.py
tests/test_autochunk/test_diffuser/test_diffuser_utils.py
+120
-0
tests/test_autochunk/test_diffuser/test_unet.py
tests/test_autochunk/test_diffuser/test_unet.py
+70
-0
tests/test_autochunk/test_evoformer_codegen.py
tests/test_autochunk/test_evoformer_codegen.py
+0
-163
tests/test_autochunk/test_evoformer_stack_codegen.py
tests/test_autochunk/test_evoformer_stack_codegen.py
+0
-163
tests/test_autochunk/test_extramsa_codegen.py
tests/test_autochunk/test_extramsa_codegen.py
+0
-164
tests/test_autochunk/test_simple_evoformer_search.py
tests/test_autochunk/test_simple_evoformer_search.py
+0
-97
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
+65
-0
tests/test_autochunk/test_transformer/test_transformer_utils.py
...test_autochunk/test_transformer/test_transformer_utils.py
+123
-0
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
63199c66
...
@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
...
@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
import
torch
import
torch
import
colossalai
import
colossalai
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
if
CODEGEN_AVAILABLE
:
AUTOCHUNK_AVAILABLE
=
CODEGEN_AVAILABLE
and
is_compatible_with_meta
()
if
AUTOCHUNK_AVAILABLE
:
from
torch.fx.graph
import
(
from
torch.fx.graph
import
(
CodeGen
,
CodeGen
,
PythonCode
,
PythonCode
,
...
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
...
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
node_idx
+=
1
node_idx
+=
1
if
CODEGEN
_AVAILABLE
:
if
AUTOCHUNK
_AVAILABLE
:
class
AutoChunkCodeGen
(
CodeGen
):
class
AutoChunkCodeGen
(
CodeGen
):
...
...
colossalai/autochunk/search_chunk.py
View file @
63199c66
...
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
...
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
from
.select_chunk
import
SelectChunk
from
.select_chunk
import
SelectChunk
from
.trace_flow
import
TraceFlow
from
.trace_flow
import
TraceFlow
from
.trace_indice
import
TraceIndice
from
.trace_indice
import
TraceIndice
from
.utils
import
get_logger
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
from
.utils
import
(
find_chunk_compute_input_and_output_nodes
,
get_logger
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
)
class
SearchChunk
(
object
):
class
SearchChunk
(
object
):
...
@@ -114,6 +120,12 @@ class SearchChunk(object):
...
@@ -114,6 +120,12 @@ class SearchChunk(object):
chunk_region_start (int)
chunk_region_start (int)
chunk_region_end (int)
chunk_region_end (int)
"""
"""
# check if peak node already in chunkinfo
if
chunk_regions
is
not
None
:
for
i
in
chunk_regions
:
if
i
[
"region"
][
0
]
<
peak_node_idx
<=
i
[
"region"
][
1
]:
return
None
free_vars
=
self
.
_get_free_var_idx
()
free_vars
=
self
.
_get_free_var_idx
()
free_var_num
=
len
(
free_vars
)
free_var_num
=
len
(
free_vars
)
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
...
@@ -152,55 +164,6 @@ class SearchChunk(object):
...
@@ -152,55 +164,6 @@ class SearchChunk(object):
chunk_region_end
=
region
[
0
]
-
1
chunk_region_end
=
region
[
0
]
-
1
return
chunk_region_start
,
chunk_region_end
return
chunk_region_start
,
chunk_region_end
def
_find_chunk_info
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
)
->
List
:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
trace_indice
.
node_list
[
end_idx
]
chunk_infos
=
[]
for
end_dim
,
_
in
enumerate
(
end_trace
[
"indice"
]):
if
len
(
start_traces
)
>
1
:
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
_
in
enumerate
(
start_trace
[
"indice"
]):
# dim size cannot be 1
if
(
get_node_shape
(
end_node
)[
end_dim
]
==
1
or
get_node_shape
(
start_node
)[
start_dim
]
==
1
):
continue
# must have users
if
len
(
end_node
.
users
)
==
0
:
continue
# check index source align
if
not
self
.
trace_flow
.
check_index_source
(
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
continue
# check index copmute
if
not
self
.
trace_flow
.
check_index_compute
(
start_idx
,
end_dim
,
end_node
,
end_idx
):
continue
# flow search
chunk_info
=
self
.
trace_flow
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
)
if
chunk_info
is
None
:
continue
# check index copmute
if
not
self
.
trace_flow
.
check_index_duplicate
(
chunk_info
):
continue
chunk_infos
.
append
(
chunk_info
)
return
chunk_infos
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
:
Tuple
,
peak_node
:
Node
)
->
List
:
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
:
Tuple
,
peak_node
:
Node
)
->
List
:
"""
"""
Search every possible region within the max chunk region.
Search every possible region within the max chunk region.
...
@@ -228,9 +191,8 @@ class SearchChunk(object):
...
@@ -228,9 +191,8 @@ class SearchChunk(object):
if
is_non_compute_node
(
self
.
trace_indice
.
node_list
[
start_idx
])
or
is_non_compute_node
(
if
is_non_compute_node
(
self
.
trace_indice
.
node_list
[
start_idx
])
or
is_non_compute_node
(
self
.
trace_indice
.
node_list
[
end_idx
]):
self
.
trace_indice
.
node_list
[
end_idx
]):
continue
continue
# select free dim
# select free dim
chunk_info
=
self
.
_
find_chunk_info
(
input_trace
,
output_trace
,
start_idx
,
end_idx
)
chunk_info
=
self
.
trace_flow
.
find_chunk_info
(
input_trace
,
output_trace
,
start_idx
,
end_idx
)
if
len
(
chunk_info
)
>
0
:
if
len
(
chunk_info
)
>
0
:
possible_chunk_region
.
extend
(
chunk_info
)
possible_chunk_region
.
extend
(
chunk_info
)
return
possible_chunk_region
return
possible_chunk_region
...
...
colossalai/autochunk/select_chunk.py
View file @
63199c66
...
@@ -5,6 +5,7 @@ from .utils import is_non_compute_node
...
@@ -5,6 +5,7 @@ from .utils import is_non_compute_node
class
SelectChunk
(
object
):
class
SelectChunk
(
object
):
def
__init__
(
def
__init__
(
self
,
self
,
trace_indice
:
TraceIndice
,
trace_indice
:
TraceIndice
,
...
@@ -21,9 +22,7 @@ class SelectChunk(object):
...
@@ -21,9 +22,7 @@ class SelectChunk(object):
else
:
else
:
self
.
stratge
=
"min_memory"
self
.
stratge
=
"min_memory"
def
_select_best_chunk_region
(
def
_select_best_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
if
self
.
stratge
==
"min_memory"
:
if
self
.
stratge
==
"min_memory"
:
best_region
=
self
.
_select_min_memory_chunk_region
(
best_region
=
self
.
_select_min_memory_chunk_region
(
possible_chunk_regions
,
possible_chunk_regions
,
...
@@ -44,9 +43,8 @@ class SelectChunk(object):
...
@@ -44,9 +43,8 @@ class SelectChunk(object):
raise
RuntimeError
()
raise
RuntimeError
()
return
best_region
return
best_region
def
_select_fit_memory_chunk_region
(
def
_select_fit_memory_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
mem_peak
):
):
# stop chunk if max memory satisfy memory limit
# stop chunk if max memory satisfy memory limit
if
max
(
mem_peak
)
<
self
.
max_memory
:
if
max
(
mem_peak
)
<
self
.
max_memory
:
return
None
return
None
...
@@ -63,33 +61,26 @@ class SelectChunk(object):
...
@@ -63,33 +61,26 @@ class SelectChunk(object):
if
len
(
possible_chunk_regions
)
==
0
:
if
len
(
possible_chunk_regions
)
==
0
:
return
None
return
None
max_possible_chunk_region
=
(
min
([
i
[
"region"
][
0
]
for
i
in
possible_chunk_regions
]),
max
([
i
[
"region"
][
1
]
for
i
in
possible_chunk_regions
]))
# get mem for chunk region
# get mem for chunk region
regions_dict
=
[]
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
trace_indice
.
node_list
,
cur_region
)
self
.
trace_indice
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_node_list
,
cur_chunk_infos
)[
0
]
cur_node_list
,
cur_chunk_infos
cur_chunk_region_peak
=
cur_mem_peak
[
max_possible_chunk_region
[
0
]:
max_possible_chunk_region
[
1
]
+
1
]
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]
:
max_chunk_region
[
1
]
+
1
]
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
if
cur_chunk_region_max_peak
<
self
.
max_memory
:
if
cur_chunk_region_max_peak
<
self
.
max_memory
:
regions_dict
.
append
(
regions_dict
.
append
({
{
"chunk_info"
:
region
,
"chunk_info"
:
region
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_len"
:
self
.
_get_compute_node_num
(
"chunk_len"
:
self
.
_get_compute_node_num
(
region
[
"region"
][
0
],
region
[
"region"
][
1
]),
region
[
"region"
][
0
],
region
[
"region"
][
1
]
),
"reorder_chunk_info"
:
cur_region
,
"reorder_chunk_info"
:
cur_region
,
"reorder_node_list"
:
cur_node_list
,
"reorder_node_list"
:
cur_node_list
,
}
})
)
# no region found
# no region found
if
len
(
regions_dict
)
==
0
:
if
len
(
regions_dict
)
==
0
:
raise
RuntimeError
(
"Search failed. Try a larger memory threshold."
)
raise
RuntimeError
(
"Search failed. Try a larger memory threshold."
)
...
@@ -113,20 +104,13 @@ class SelectChunk(object):
...
@@ -113,20 +104,13 @@ class SelectChunk(object):
chunk_size
*=
2
chunk_size
*=
2
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
cur_chunk_infos
=
chunk_infos
+
[
reorder_chunk_info
]
cur_chunk_infos
=
chunk_infos
+
[
reorder_chunk_info
]
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
"reorder_node_list"
],
chunk_region_dict
[
"reorder_node_list"
],
cur_chunk_infos
cur_chunk_infos
)[
0
]
)[
0
]
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
reorder_chunk_info
[
"region"
][
0
]:
reorder_chunk_info
[
"region"
][
1
]
+
1
])
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
reorder_chunk_info
[
"region"
][
0
]
:
reorder_chunk_info
[
"region"
][
1
]
+
1
]
)
# search exact size
# search exact size
chunk_info
=
chunk_region_dict
[
"chunk_info"
]
chunk_info
=
chunk_region_dict
[
"chunk_info"
]
chunk_info
[
"chunk_size"
]
=
self
.
_chunk_size_binary_search
(
chunk_info
[
"chunk_size"
]
=
self
.
_chunk_size_binary_search
(
chunk_size
//
2
,
chunk_size
,
chunk_region_dict
,
chunk_size
//
2
,
chunk_size
,
chunk_region_dict
,
chunk_infos
chunk_infos
)
)
return
chunk_info
return
chunk_info
def
_chunk_size_binary_search
(
self
,
left
,
right
,
chunk_region_dict
,
chunk_infos
):
def
_chunk_size_binary_search
(
self
,
left
,
right
,
chunk_region_dict
,
chunk_infos
):
...
@@ -139,12 +123,9 @@ class SelectChunk(object):
...
@@ -139,12 +123,9 @@ class SelectChunk(object):
mid
=
int
((
left
+
right
)
/
2
+
0.5
)
mid
=
int
((
left
+
right
)
/
2
+
0.5
)
chunk_info
[
"chunk_size"
]
=
mid
chunk_info
[
"chunk_size"
]
=
mid
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
"reorder_node_list"
],
chunk_region_dict
[
"reorder_node_list"
],
cur_chunk_infos
cur_chunk_infos
)[
0
]
)[
0
]
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
chunk_info
[
"region"
][
0
]:
chunk_info
[
"region"
][
1
]
+
1
])
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
chunk_info
[
"region"
][
0
]
:
chunk_info
[
"region"
][
1
]
+
1
]
)
if
cur_chunk_max_mem
>=
self
.
max_memory
:
if
cur_chunk_max_mem
>=
self
.
max_memory
:
right
=
mid
-
gap
right
=
mid
-
gap
else
:
else
:
...
@@ -153,14 +134,13 @@ class SelectChunk(object):
...
@@ -153,14 +134,13 @@ class SelectChunk(object):
def
_get_compute_node_num
(
self
,
start
,
end
):
def
_get_compute_node_num
(
self
,
start
,
end
):
count
=
0
count
=
0
for
i
in
self
.
trace_indice
.
node_list
[
start
:
end
+
1
]:
for
i
in
self
.
trace_indice
.
node_list
[
start
:
end
+
1
]:
if
not
is_non_compute_node
(
i
):
if
not
is_non_compute_node
(
i
):
count
+=
1
count
+=
1
return
count
return
count
def
_select_min_memory_chunk_region
(
def
_select_min_memory_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
mem_peak
):
):
# remove illegal regions
# remove illegal regions
illegal_regions
=
[]
illegal_regions
=
[]
for
i
in
possible_chunk_regions
:
for
i
in
possible_chunk_regions
:
...
@@ -173,37 +153,31 @@ class SelectChunk(object):
...
@@ -173,37 +153,31 @@ class SelectChunk(object):
if
len
(
possible_chunk_regions
)
==
0
:
if
len
(
possible_chunk_regions
)
==
0
:
return
None
return
None
# get max possible chunk region
max_possible_chunk_region
=
(
min
([
i
[
"region"
][
0
]
for
i
in
possible_chunk_regions
]),
max
([
i
[
"region"
][
1
]
for
i
in
possible_chunk_regions
]))
# get mem for chunk region
# get mem for chunk region
regions_dict
=
[]
regions_dict
_list
=
[]
for
region
in
possible_chunk_regions
:
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
trace_indice
.
node_list
,
cur_region
)
self
.
trace_indice
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
cur_node_list
,
cur_chunk_infos
)[
0
]
cur_node_list
,
cur_chunk_infos
cur_chunk_region_peak
=
cur_mem_peak
[
max_possible_chunk_region
[
0
]:
max_possible_chunk_region
[
1
]
+
1
]
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]
:
max_chunk_region
[
1
]
+
1
]
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
regions_dict
.
append
(
regions_dict_list
.
append
({
{
"chunk_info"
:
region
,
"chunk_info"
:
region
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_len"
:
self
.
_get_compute_node_num
(
"chunk_len"
:
self
.
_get_compute_node_num
(
region
[
"region"
][
0
],
region
[
"region"
][
1
]),
region
[
"region"
][
0
],
region
[
"region"
][
1
]
),
"reorder_chunk_info"
:
cur_region
,
"reorder_chunk_info"
:
cur_region
,
"reorder_node_list"
:
cur_node_list
,
"reorder_node_list"
:
cur_node_list
,
}
})
)
# select the min mem
# select the min mem
chunk_max_mem
=
[
i
[
"chunk_max_mem"
]
for
i
in
regions_dict
]
chunk_max_mem
=
[
i
[
"chunk_max_mem"
]
for
i
in
regions_dict
_list
]
best_region_idx
=
chunk_max_mem
.
index
(
min
(
chunk_max_mem
))
best_region_idx
=
chunk_max_mem
.
index
(
min
(
chunk_max_mem
))
best_region
=
regions_dict
[
best_region_idx
][
"chunk_info"
]
best_region
=
regions_dict
_list
[
best_region_idx
][
"chunk_info"
]
if
best_region
is
not
None
:
if
best_region
is
not
None
:
best_region
[
"chunk_size"
]
=
1
best_region
[
"chunk_size"
]
=
1
return
best_region
return
best_region
...
@@ -216,9 +190,7 @@ class SelectChunk(object):
...
@@ -216,9 +190,7 @@ class SelectChunk(object):
return
False
return
False
for
i
in
chunk_infos
:
for
i
in
chunk_infos
:
region
=
i
[
"region"
]
region
=
i
[
"region"
]
if
not
(
if
not
((
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
or
(
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])):
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])
):
return
False
return
False
return
True
return
True
colossalai/autochunk/trace_flow.py
View file @
63199c66
...
@@ -8,9 +8,9 @@ from .utils import (
...
@@ -8,9 +8,9 @@ from .utils import (
find_chunk_compute_input_and_output_nodes
,
find_chunk_compute_input_and_output_nodes
,
find_idx_by_name
,
find_idx_by_name
,
flat_list
,
flat_list
,
get_node_name
,
get_node_shape
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
)
)
...
@@ -79,43 +79,6 @@ class TraceFlow(object):
...
@@ -79,43 +79,6 @@ class TraceFlow(object):
return
node_dim
return
node_dim
return
None
return
None
def
check_index_duplicate
(
self
,
chunk_infos
,
return_dim
=
False
):
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_indice
.
node_list
[
k
])
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
trace_indice
.
node_list
[
chunk_infos
[
"region"
][
0
]:
chunk_infos
[
"region"
][
1
]
+
1
]:
if
is_non_compute_node_except_placeholder
(
node
):
continue
count
=
0
duplicate_dims
=
[]
node_trace_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
duplicate_dim
=
[]
duplicate_flag
=
False
dim_source
=
node_trace_source
[
node_dim
]
for
k
,
v
in
dim_source
.
items
():
if
chunk_infos
[
"region"
][
0
]
<=
k
<=
chunk_infos
[
"region"
][
1
]:
if
k
in
input_dim_after_node
and
input_dim_after_node
[
k
]
in
v
:
duplicate_flag
=
True
duplicate_dim
.
append
((
k
,
v
))
duplicate_dims
.
append
(
duplicate_dim
)
if
duplicate_flag
:
count
+=
1
if
count
>
1
:
if
return_dim
:
return
False
,
duplicate_dims
else
:
return
False
if
return_dim
:
return
True
,
None
else
:
return
True
def
_assgin_single_node_flow
(
def
_assgin_single_node_flow
(
self
,
self
,
arg_node
:
Node
,
arg_node
:
Node
,
...
@@ -225,9 +188,12 @@ class TraceFlow(object):
...
@@ -225,9 +188,12 @@ class TraceFlow(object):
if
flow_flag
==
False
:
if
flow_flag
==
False
:
return
None
return
None
if
len
(
arg_list
)
==
2
:
if
len
(
arg_list
)
>=
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
,
"truediv"
]):
# need to mark fix dim
if
any
(
i
==
get_node_name
(
cur_node
)
for
i
in
[
"add"
,
"mul"
,
"truediv"
,
"sub"
,
"where"
]):
for
arg
in
arg_list
:
for
arg
in
arg_list
:
if
get_node_shape
(
arg
)
is
None
:
continue
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
continue
continue
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
...
@@ -240,9 +206,8 @@ class TraceFlow(object):
...
@@ -240,9 +206,8 @@ class TraceFlow(object):
return
None
return
None
if
i
not
in
arg_fix_dim
:
if
i
not
in
arg_fix_dim
:
arg_fix_dim
.
append
(
i
)
arg_fix_dim
.
append
(
i
)
elif
"einsum"
in
cur_node
.
name
:
elif
any
(
i
==
get_node_name
(
cur_node
)
pass
for
i
in
[
"einsum"
,
"matmul"
,
"view"
,
"to"
,
"getitem"
,
"tensor"
,
"type"
]):
elif
"matmul"
in
cur_node
.
name
:
pass
pass
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -426,7 +391,7 @@ class TraceFlow(object):
...
@@ -426,7 +391,7 @@ class TraceFlow(object):
reshape_size
=
{}
reshape_size
=
{}
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]]
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]]
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
if
any
(
i
==
get_
node
_
name
(
node
)
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
flat_list
(
node
.
args
[
1
:])
reshape_args
=
flat_list
(
node
.
args
[
1
:])
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
new_shape
=
""
new_shape
=
""
...
@@ -443,3 +408,62 @@ class TraceFlow(object):
...
@@ -443,3 +408,62 @@ class TraceFlow(object):
reshape_size
[
node
.
name
]
=
[
origin_shape
,
new_shape
]
reshape_size
[
node
.
name
]
=
[
origin_shape
,
new_shape
]
chunk_info
[
"reshape_size"
]
=
reshape_size
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
return
chunk_info
def
find_chunk_info
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
)
->
List
:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces
=
input_trace
[
start_idx
]
if
len
(
start_traces
)
>
1
:
# TODO need to be removed
return
[]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
trace_indice
.
node_list
[
end_idx
]
chunk_infos
=
[]
for
end_dim
,
_
in
enumerate
(
end_trace
[
"indice"
]):
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
_
in
enumerate
(
start_trace
[
"indice"
]):
if
not
self
.
_check_region_start_end
(
start_node
,
start_dim
,
start_idx
,
end_node
,
end_dim
,
end_idx
):
continue
# flow search
chunk_info
=
self
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
)
if
chunk_info
is
None
:
continue
chunk_infos
.
append
(
chunk_info
)
return
chunk_infos
def
_check_region_start_end
(
self
,
start_node
:
Node
,
start_dim
:
int
,
start_idx
:
int
,
end_node
:
Node
,
end_dim
:
int
,
end_idx
:
int
)
->
bool
:
"""
check if region start and end is legal
"""
# dim cannot be None
if
(
get_node_shape
(
end_node
)
is
None
or
get_node_shape
(
start_node
)
is
None
):
return
False
# dim size cannot be 1
if
(
get_node_shape
(
end_node
)[
end_dim
]
==
1
or
get_node_shape
(
start_node
)[
start_dim
]
==
1
):
return
False
# must have users
if
len
(
end_node
.
users
)
==
0
:
return
False
# check index source align
if
not
self
.
check_index_source
(
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
return
False
# check index copmute
if
not
self
.
check_index_compute
(
start_idx
,
end_dim
,
end_node
,
end_idx
):
return
False
return
True
colossalai/autochunk/trace_indice.py
View file @
63199c66
...
@@ -3,7 +3,14 @@ from typing import Dict, List, Tuple
...
@@ -3,7 +3,14 @@ from typing import Dict, List, Tuple
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
.utils
import
find_first_tensor_arg
,
find_idx_by_name
,
flat_list
,
get_node_shape
from
.utils
import
(
find_first_tensor_arg
,
find_idx_by_name
,
flat_list
,
get_module_node_name
,
get_node_name
,
get_node_shape
,
)
class
TraceIndice
(
object
):
class
TraceIndice
(
object
):
...
@@ -36,7 +43,7 @@ class TraceIndice(object):
...
@@ -36,7 +43,7 @@ class TraceIndice(object):
self
.
trace_range
=
[]
self
.
trace_range
=
[]
self
.
active_node_list
=
[]
self
.
active_node_list
=
[]
def
_init_indice_trace_list
(
self
):
def
_init_indice_trace_list
(
self
)
->
List
:
indice_trace_list
=
[]
indice_trace_list
=
[]
for
n
in
self
.
node_list
:
for
n
in
self
.
node_list
:
if
get_node_shape
(
n
)
!=
None
:
if
get_node_shape
(
n
)
!=
None
:
...
@@ -54,7 +61,7 @@ class TraceIndice(object):
...
@@ -54,7 +61,7 @@ class TraceIndice(object):
self
.
trace_range
=
trace_range
self
.
trace_range
=
trace_range
self
.
active_node_list
=
active_node_list
self
.
active_node_list
=
active_node_list
def
_add_indice
(
self
):
def
_add_indice
(
self
)
->
int
:
"""
"""
Update the count and return it. To record the idx number.
Update the count and return it. To record the idx number.
...
@@ -64,39 +71,30 @@ class TraceIndice(object):
...
@@ -64,39 +71,30 @@ class TraceIndice(object):
self
.
indice_count
+=
1
self
.
indice_count
+=
1
return
self
.
indice_count
return
self
.
indice_count
def
_del_dim
(
self
,
idx
,
dim_idx
):
def
_del_dim
(
self
,
idx
:
int
,
dim_idx
:
int
)
->
None
:
"""
delete a dim for indice, compute and source
"""
self
.
indice_trace_list
[
idx
][
"indice"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"indice"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
def
_add_dim
(
self
,
node_idx
:
int
,
dim_idx
:
int
)
->
None
:
"""
add a dim for indice, compute and source
"""
self
.
indice_trace_list
[
node_idx
][
"indice"
].
insert
(
dim_idx
,
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"indice"
].
insert
(
dim_idx
,
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
indice_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
indice_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
self
.
indice_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
def
_transform_indice
(
self
,
node
,
node_dim
):
def
_add_source
(
node_idx
=
self
.
_find_indice_trace_from_node
(
node
)
self
,
dims
=
list
(
range
(
len
(
node_idx
)))
node_from
:
Node
,
return
dims
[
node_dim
]
node_from_dim
:
int
,
node_to
:
Node
,
def
_inherit_indice
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
):
node_to_dim
:
int
,
node_from_dim
=
self
.
_transform_indice
(
node_from
,
node_from_dim
)
init
=
False
,
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
)
->
None
:
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"indice"
][
node_to_dim
]
=
node_from_trace
[
"indice"
][
node_from_dim
]
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
])
self
.
_add_source
(
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
True
)
def
_inherit_all_computation
(
self
,
node_from
,
node_to
):
node_from_compute
=
self
.
_find_compute_trace_from_node
(
node_from
)
node_to_compute
=
self
.
_find_compute_trace_from_node
(
node_to
)
assert
len
(
node_from_compute
)
==
len
(
node_to_compute
)
for
i
in
range
(
len
(
node_from_compute
)):
self
.
_add_source
(
node_from
,
i
,
node_to
,
i
)
node_to_compute
[
i
]
=
copy
.
deepcopy
(
node_from_compute
[
i
])
def
_add_source
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
False
):
node_from_dim
=
self
.
_transform_indice
(
node_from
,
node_from_dim
)
node_from_dim
=
self
.
_transform_indice
(
node_from
,
node_from_dim
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
...
@@ -119,7 +117,50 @@ class TraceIndice(object):
...
@@ -119,7 +117,50 @@ class TraceIndice(object):
if
d
not
in
node_to_trace_source
[
node_to_dim
][
node_idx
]:
if
d
not
in
node_to_trace_source
[
node_to_dim
][
node_idx
]:
node_to_trace_source
[
node_to_dim
][
node_idx
].
append
(
d
)
node_to_trace_source
[
node_to_dim
][
node_idx
].
append
(
d
)
def
_mark_computation_from_node
(
self
,
node_from
,
node_to
,
exclude
=
None
):
def
_transform_indice
(
self
,
node
:
Node
,
node_dim
:
int
)
->
int
:
node_idx
=
self
.
_find_indice_trace_from_node
(
node
)
dims
=
list
(
range
(
len
(
node_idx
)))
return
dims
[
node_dim
]
def
_inherit_indice
(
self
,
node_from
:
Node
,
node_from_dim
:
int
,
node_to
:
Node
,
node_to_dim
:
int
,
init
:
bool
=
True
,
)
->
None
:
"""
node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source
"""
node_from_dim
=
self
.
_transform_indice
(
node_from
,
node_from_dim
)
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
if
init
:
node_to_trace
[
"indice"
][
node_to_dim
]
=
node_from_trace
[
"indice"
][
node_from_dim
]
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
])
else
:
for
j
in
node_from_trace
[
"compute"
][
node_from_dim
]:
if
j
not
in
node_to_trace
[
"compute"
][
node_to_dim
]:
node_to_trace
[
"compute"
][
node_to_dim
].
append
(
j
)
self
.
_add_source
(
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
)
def
_inherit_all_indice
(
self
,
node_from
:
Node
,
node_to
:
Node
)
->
None
:
"""
inherit all dims with init
"""
# find indice just for assert length
node_from_indice
=
self
.
_find_indice_trace_from_node
(
node_from
)
node_to_indice
=
self
.
_find_indice_trace_from_node
(
node_to
)
assert
len
(
node_from_indice
)
==
len
(
node_to_indice
)
for
i
in
range
(
len
(
node_from_indice
)):
self
.
_inherit_indice
(
node_from
,
i
,
node_to
,
i
,
init
=
True
)
def
_inherit_more_indice_from_node
(
self
,
node_from
:
Node
,
node_to
:
Node
,
exclude
:
List
=
None
)
->
None
:
"""
inheirt indice from node without init
"""
if
exclude
==
None
:
if
exclude
==
None
:
exclude
=
[]
exclude
=
[]
else
:
else
:
...
@@ -130,12 +171,9 @@ class TraceIndice(object):
...
@@ -130,12 +171,9 @@ class TraceIndice(object):
for
i
in
range
(
-
1
,
-
min
(
len
(
node_from_compute
),
len
(
node_to_compute
))
-
1
,
-
1
):
for
i
in
range
(
-
1
,
-
min
(
len
(
node_from_compute
),
len
(
node_to_compute
))
-
1
,
-
1
):
if
self
.
_transform_indice
(
node_to
,
i
)
in
exclude
:
if
self
.
_transform_indice
(
node_to
,
i
)
in
exclude
:
continue
continue
self
.
_add_source
(
node_from
,
i
,
node_to
,
i
)
self
.
_inherit_indice
(
node_from
,
i
,
node_to
,
i
,
init
=
False
)
for
j
in
node_from_compute
[
i
]:
if
j
not
in
node_to_compute
[
i
]:
node_to_compute
[
i
].
append
(
j
)
def
_mark_computation
(
self
,
node
,
idx
,
dim
)
:
def
_mark_computation
(
self
,
node
:
Node
,
idx
:
int
,
dim
:
int
)
->
None
:
"""
"""
Mark some dims of node as computed.
Mark some dims of node as computed.
...
@@ -152,7 +190,7 @@ class TraceIndice(object):
...
@@ -152,7 +190,7 @@ class TraceIndice(object):
if
idx
not
in
self
.
indice_trace_list
[
idx
][
"compute"
][
cur_dim
]:
if
idx
not
in
self
.
indice_trace_list
[
idx
][
"compute"
][
cur_dim
]:
self
.
indice_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
self
.
indice_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
def
_find_trace_from_node
(
self
,
node
)
:
def
_find_trace_from_node
(
self
,
node
:
Node
)
->
Dict
:
"""
"""
Find node idx and compute trace by the node.
Find node idx and compute trace by the node.
...
@@ -166,7 +204,7 @@ class TraceIndice(object):
...
@@ -166,7 +204,7 @@ class TraceIndice(object):
node_dict
=
self
.
indice_trace_list
[
node_idx
]
node_dict
=
self
.
indice_trace_list
[
node_idx
]
return
node_dict
return
node_dict
def
_find_source_trace_from_node
(
self
,
node
)
:
def
_find_source_trace_from_node
(
self
,
node
:
Node
)
->
List
:
"""
"""
Find node source trace by the node.
Find node source trace by the node.
...
@@ -180,7 +218,7 @@ class TraceIndice(object):
...
@@ -180,7 +218,7 @@ class TraceIndice(object):
node_dict
=
self
.
indice_trace_list
[
node_idx
]
node_dict
=
self
.
indice_trace_list
[
node_idx
]
return
node_dict
[
"source"
]
return
node_dict
[
"source"
]
def
_find_indice_trace_from_node
(
self
,
node
):
def
_find_indice_trace_from_node
(
self
,
node
)
->
List
:
"""
"""
Find node idx trace by the node.
Find node idx trace by the node.
...
@@ -192,7 +230,7 @@ class TraceIndice(object):
...
@@ -192,7 +230,7 @@ class TraceIndice(object):
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
indice_trace_list
[
node_idx
][
"indice"
]
return
self
.
indice_trace_list
[
node_idx
][
"indice"
]
def
_find_compute_trace_from_node
(
self
,
node
)
:
def
_find_compute_trace_from_node
(
self
,
node
:
Node
)
->
List
:
"""
"""
Find node compute trace by the node.
Find node compute trace by the node.
...
@@ -204,7 +242,7 @@ class TraceIndice(object):
...
@@ -204,7 +242,7 @@ class TraceIndice(object):
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
indice_trace_list
[
node_idx
][
"compute"
]
return
self
.
indice_trace_list
[
node_idx
][
"compute"
]
def
_assign_indice_as_input
(
self
,
node
:
Node
,
node_idx
:
int
,
input_node
=
None
):
def
_assign_indice_as_input
(
self
,
node
:
Node
,
node_idx
:
int
,
input_node
=
None
)
->
None
:
"""
"""
Assign node's trace as its input node.
Assign node's trace as its input node.
...
@@ -214,15 +252,9 @@ class TraceIndice(object):
...
@@ -214,15 +252,9 @@ class TraceIndice(object):
"""
"""
if
input_node
==
None
:
if
input_node
==
None
:
input_node
=
find_first_tensor_arg
(
node
)
input_node
=
find_first_tensor_arg
(
node
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
self
.
_inherit_all_indice
(
input_node
,
node
)
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"indice"
]
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
self
.
indice_trace_list
[
node_idx
][
"indice"
]
=
new_idx_trace
self
.
_inherit_all_computation
(
input_node
,
node
)
def
_assign_all_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
def
_assign_all_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
"""
"""
Add new indice for all node's dims.
Add new indice for all node's dims.
...
@@ -238,7 +270,7 @@ class TraceIndice(object):
...
@@ -238,7 +270,7 @@ class TraceIndice(object):
new_trace
.
append
(
self
.
_add_indice
())
new_trace
.
append
(
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"indice"
]
=
new_trace
self
.
indice_trace_list
[
node_idx
][
"indice"
]
=
new_trace
def
_assign_transpose_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_transpose_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for transpose op.
Assign indice for transpose op.
1. swap input's dim according to transpose args
1. swap input's dim according to transpose args
...
@@ -255,7 +287,7 @@ class TraceIndice(object):
...
@@ -255,7 +287,7 @@ class TraceIndice(object):
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
def
_assign_permute_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_permute_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for permute op.
Assign indice for permute op.
1. swap input's dim according to permute args
1. swap input's dim according to permute args
...
@@ -272,7 +304,7 @@ class TraceIndice(object):
...
@@ -272,7 +304,7 @@ class TraceIndice(object):
for
idx
,
d
in
enumerate
(
permute_dim
):
for
idx
,
d
in
enumerate
(
permute_dim
):
self
.
_inherit_indice
(
input_node
,
d
,
node
,
idx
)
self
.
_inherit_indice
(
input_node
,
d
,
node
,
idx
)
def
_assign_linear_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_linear_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for linear op.
Assign indice for linear op.
1. copy trace from input node and change last indice accroding to weight
1. copy trace from input node and change last indice accroding to weight
...
@@ -293,7 +325,23 @@ class TraceIndice(object):
...
@@ -293,7 +325,23 @@ class TraceIndice(object):
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
def
_assign_matmul_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_addmm_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
Assign indice for addmm op.
Args:
node (node)
node_idx (int)
"""
bias
,
input_node
,
weight
=
node
.
args
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
)
self
.
_inherit_indice
(
weight
,
1
,
node
,
-
1
)
self
.
_inherit_indice
(
bias
,
-
1
,
node
,
-
1
)
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
def
_assign_matmul_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for matmul op.
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
...
@@ -310,7 +358,7 @@ class TraceIndice(object):
...
@@ -310,7 +358,7 @@ class TraceIndice(object):
self
.
_assign_indice_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_inherit_indice
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_inherit_indice
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_
mark_computation
_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_
inherit_more_indice
_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
def
_assign_layernorm_indice
(
self
,
node
,
idx
):
def
_assign_layernorm_indice
(
self
,
node
,
idx
):
...
@@ -341,14 +389,13 @@ class TraceIndice(object):
...
@@ -341,14 +389,13 @@ class TraceIndice(object):
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
nodes_in
.
append
(
node_in
)
nodes_in
.
append
(
node_in
)
self
.
_mark_computation_from_node
(
node_in
,
node
)
self
.
_inherit_more_indice_from_node
(
node_in
,
node
)
assert
len
(
nodes_in
)
<=
2
def
_assgin_no_change_indice
(
self
,
node
,
idx
):
def
_assgin_no_change_indice
(
self
,
node
,
idx
):
self
.
_assign_indice_as_input
(
node
,
idx
)
self
.
_assign_indice_as_input
(
node
,
idx
)
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
self
.
_
mark_computation
_from_node
(
node_in
,
node
)
self
.
_
inherit_more_indice
_from_node
(
node_in
,
node
)
def
_assign_einsum_indice
(
self
,
node
,
idx
):
def
_assign_einsum_indice
(
self
,
node
,
idx
):
"""
"""
...
@@ -365,7 +412,7 @@ class TraceIndice(object):
...
@@ -365,7 +412,7 @@ class TraceIndice(object):
left
,
right
=
patterns
.
split
(
"->"
)
left
,
right
=
patterns
.
split
(
"->"
)
left
=
left
.
split
(
","
)
left
=
left
.
split
(
","
)
if
'
...
'
in
right
:
if
"
...
"
in
right
:
replace_list
=
"!@#$%^&*"
replace_list
=
"!@#$%^&*"
target_len
=
len
(
get_node_shape
(
node
))
target_len
=
len
(
get_node_shape
(
node
))
add_len
=
target_len
-
len
(
right
)
+
3
add_len
=
target_len
-
len
(
right
)
+
3
...
@@ -399,24 +446,22 @@ class TraceIndice(object):
...
@@ -399,24 +446,22 @@ class TraceIndice(object):
self
.
_assign_indice_as_input
(
node
,
idx
)
self
.
_assign_indice_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
def
_assign_
unsqueeze
_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_
split
_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for unsqueeze op.
Assign indice for split op.
1. assign new indice for unsqueeze dim
Args:
Args:
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_del_dim
(
node_idx
,
-
1
)
for
_
in
range
(
len
(
get_node_shape
(
node
.
args
[
0
]))):
self
.
_add_dim
(
node_idx
,
0
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
dim_idx
=
node
.
args
[
1
]
dim_idx
=
node
.
kwargs
[
"dim"
]
# unsqueeze(-1) = unsqueeze(shape_num + 1)
self
.
_del_dim
(
node_idx
,
dim_idx
)
if
dim_idx
<
0
:
dim_idx
=
list
(
range
(
len
(
get_node_shape
(
node
))))[
dim_idx
]
self
.
_add_dim
(
node_idx
,
dim_idx
)
self
.
_add_dim
(
node_idx
,
dim_idx
)
def
_assign_
dropout
_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_
unsqueeze
_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for unsqueeze op.
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
1. assign new indice for unsqueeze dim
...
@@ -425,9 +470,15 @@ class TraceIndice(object):
...
@@ -425,9 +470,15 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
dim_idx
=
node
.
args
[
1
]
# unsqueeze(-1) = unsqueeze(shape_num + 1)
if
dim_idx
<
0
:
dim_idx
=
list
(
range
(
len
(
get_node_shape
(
node
))))[
dim_idx
]
self
.
_add_dim
(
node_idx
,
dim_idx
)
def
_assign_ones_like_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_ones_like_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for oneslike op.
Assign indice for oneslike op.
1. assign new indice for all dim
1. assign new indice for all dim
...
@@ -438,7 +489,7 @@ class TraceIndice(object):
...
@@ -438,7 +489,7 @@ class TraceIndice(object):
"""
"""
self
.
_assign_all_indice
(
node
,
node_idx
)
self
.
_assign_all_indice
(
node
,
node_idx
)
def
_assign_cat_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_cat_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for cat op.
Assign indice for cat op.
...
@@ -449,12 +500,12 @@ class TraceIndice(object):
...
@@ -449,12 +500,12 @@ class TraceIndice(object):
nodes_in
=
flat_list
(
node
.
args
[
0
])
nodes_in
=
flat_list
(
node
.
args
[
0
])
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
=
nodes_in
[
0
])
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
=
nodes_in
[
0
])
for
n
in
nodes_in
[
1
:]:
for
n
in
nodes_in
[
1
:]:
self
.
_
mark_computation
_from_node
(
n
,
node
)
self
.
_
inherit_more_indice
_from_node
(
n
,
node
)
cat_dim
=
node
.
kwargs
[
"dim"
]
cat_dim
=
node
.
kwargs
[
"dim"
]
self
.
_del_dim
(
node_idx
,
cat_dim
)
self
.
_del_dim
(
node_idx
,
cat_dim
)
self
.
_add_dim
(
node_idx
,
cat_dim
)
self
.
_add_dim
(
node_idx
,
cat_dim
)
def
_assign_sum_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_sum_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for sum op.
Assign indice for sum op.
...
@@ -466,11 +517,46 @@ class TraceIndice(object):
...
@@ -466,11 +517,46 @@ class TraceIndice(object):
self
.
_add_dim
(
node_idx
,
0
)
self
.
_add_dim
(
node_idx
,
0
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
=
nodes_in
[
0
])
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
=
nodes_in
[
0
])
for
n
in
nodes_in
[
1
:]:
for
n
in
nodes_in
[
1
:]:
self
.
_
mark_computation
_from_node
(
n
,
node
)
self
.
_
inherit_more_indice
_from_node
(
n
,
node
)
cat_dim
=
node
.
kwargs
[
"dim"
]
cat_dim
=
node
.
kwargs
[
"dim"
]
self
.
_del_dim
(
node_idx
,
cat_dim
)
self
.
_del_dim
(
node_idx
,
cat_dim
)
def
_assign_getitem_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_arange_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
Assign indice for arange op.
Args:
node (node)
node_idx (int)
"""
self
.
_assign_all_indice
(
node
,
node_idx
)
def
_assign_tensor_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
Assign indice for tensor op.
Args:
node (node)
node_idx (int)
"""
if
len
(
get_node_shape
(
node
))
==
0
:
return
else
:
raise
NotImplementedError
()
def
_assign_embedding_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
Assign indice for embedding op.
Args:
node (node)
node_idx (int)
"""
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
self
.
_add_dim
(
node_idx
,
-
1
)
def
_assign_getitem_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for getitem.
Assign indice for getitem.
getitem can act like slice sometimes
getitem can act like slice sometimes
...
@@ -480,6 +566,19 @@ class TraceIndice(object):
...
@@ -480,6 +566,19 @@ class TraceIndice(object):
node_idx (int)
node_idx (int)
"""
"""
node_args
=
flat_list
(
node
.
args
[
1
:])
node_args
=
flat_list
(
node
.
args
[
1
:])
# deal with split
if
get_node_name
(
node
.
args
[
0
])
==
"split"
:
self
.
_assign_indice_as_input
(
node
,
node_idx
)
self
.
_del_dim
(
node_idx
,
node
.
args
[
0
].
kwargs
[
"dim"
])
self
.
_add_dim
(
node_idx
,
node
.
args
[
0
].
kwargs
[
"dim"
])
return
# skip non tensor
if
get_node_shape
(
node
)
is
None
:
return
# find if slice
flag
=
False
flag
=
False
for
node_arg
in
node_args
:
for
node_arg
in
node_args
:
node_arg_str
=
str
(
node_arg
)
node_arg_str
=
str
(
node_arg
)
...
@@ -528,7 +627,7 @@ class TraceIndice(object):
...
@@ -528,7 +627,7 @@ class TraceIndice(object):
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
_assign_view_reshape_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_view_reshape_indice
(
self
,
node
:
Node
,
node_idx
:
int
)
->
None
:
"""
"""
Assign indice for view and reshape op.
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
1. get origin shape and target shape by meta info.
...
@@ -536,7 +635,7 @@ class TraceIndice(object):
...
@@ -536,7 +635,7 @@ class TraceIndice(object):
3. determine changed dim, and assgin indice for generated dim.
3. determine changed dim, and assgin indice for generated dim.
4. log changed dim and generated dim for restore
4. log changed dim and generated dim for restore
5. inherit computation.
5. inherit computation.
6.
TODO:
look into view list to see whether the view is associated with other,
6. look into view list to see whether the view is associated with other,
if so assgin equal dim according to previous view.
if so assgin equal dim according to previous view.
Args:
Args:
...
@@ -552,7 +651,7 @@ class TraceIndice(object):
...
@@ -552,7 +651,7 @@ class TraceIndice(object):
if
isinstance
(
unflated_args
[
i
],
int
):
if
isinstance
(
unflated_args
[
i
],
int
):
target_shape
.
append
(
unflated_args
[
i
])
target_shape
.
append
(
unflated_args
[
i
])
else
:
else
:
target_shape
.
app
end
(
unflated_args
[
i
].
meta
[
"fwd_out"
]
[
0
]
)
target_shape
.
ext
end
(
unflated_args
[
i
].
meta
[
"fwd_out"
])
# compute the value of -1
# compute the value of -1
if
-
1
in
target_shape
:
if
-
1
in
target_shape
:
...
@@ -579,17 +678,36 @@ class TraceIndice(object):
...
@@ -579,17 +678,36 @@ class TraceIndice(object):
dim_from
=
[
dim_equal
.
index
(
False
)]
dim_from
=
[
dim_equal
.
index
(
False
)]
dim_to
=
[
dim_equal
.
index
(
False
),
dim_equal
.
index
(
False
)
+
1
]
dim_to
=
[
dim_equal
.
index
(
False
),
dim_equal
.
index
(
False
)
+
1
]
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_del_dim
(
node_idx
,
-
1
)
elif
len_diff
==
0
:
# dim equal
dim_equal
=
[
i
==
j
for
i
,
j
in
zip
(
origin_shape
,
target_shape
[:
-
1
])]
dim_from
=
[]
dim_to
=
[]
else
:
else
:
raise
NotImplementedError
(
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
raise
NotImplementedError
(
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
# get new indice
# get new indice
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
origin_node
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
origin_node
)
idx_from
=
[
origin_trace
[
i
]
for
i
in
dim_from
]
dim_from
.
reverse
()
dim_from
.
reverse
()
for
i
in
dim_from
:
for
i
in
dim_from
:
self
.
_del_dim
(
node_idx
,
i
)
self
.
_del_dim
(
node_idx
,
i
)
for
i
in
dim_to
:
for
i
in
dim_to
:
self
.
_add_dim
(
node_idx
,
i
)
self
.
_add_dim
(
node_idx
,
i
)
dim_from
.
reverse
()
# search view list
for
view_node
,
view_dict
in
self
.
indice_view_list
.
items
():
if
(
view_dict
[
"idx_to"
]
==
idx_from
and
view_dict
[
"dim_to"
]
==
dim_from
and
view_dict
[
"dim_from"
]
==
dim_to
):
# inheirt indice from current node
for
dim_to_i
in
dim_to
:
for
dim_from_i
in
dim_from
:
self
.
_inherit_indice
(
origin_node
,
dim_from_i
,
node
,
dim_to_i
,
init
=
False
)
# inherid indice from input node of last view
for
dim_to_i
in
dim_to
:
self
.
_inherit_indice
(
view_node
.
args
[
0
],
dim_to_i
,
node
,
dim_to_i
,
init
=
False
)
# inherit computation
# inherit computation
compute_log
=
self
.
_find_compute_trace_from_node
(
origin_node
)
compute_log
=
self
.
_find_compute_trace_from_node
(
origin_node
)
...
@@ -630,7 +748,7 @@ class TraceIndice(object):
...
@@ -630,7 +748,7 @@ class TraceIndice(object):
# clear compute
# clear compute
for
dim_compute
in
trace
[
"compute"
]:
for
dim_compute
in
trace
[
"compute"
]:
for
i
in
range
(
len
(
dim_compute
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
dim_compute
)
-
1
,
-
1
,
-
1
):
if
dim_compute
[
i
]
<
trace_range
[
0
]
and
dim_compute
[
i
]
not
in
active_nodes
:
if
(
dim_compute
[
i
]
<
trace_range
[
0
]
and
dim_compute
[
i
]
not
in
active_nodes
)
:
dim_compute
.
pop
(
i
)
dim_compute
.
pop
(
i
)
continue
continue
# clear source
# clear source
...
@@ -639,59 +757,82 @@ class TraceIndice(object):
...
@@ -639,59 +757,82 @@ class TraceIndice(object):
if
k
<
trace_range
[
0
]
and
k
not
in
active_nodes
:
if
k
<
trace_range
[
0
]
and
k
not
in
active_nodes
:
dim_source
.
pop
(
k
)
dim_source
.
pop
(
k
)
def
trace_indice
(
self
):
def
trace_indice
(
self
)
->
None
:
for
idx
,
node
in
enumerate
(
self
.
node_list
):
for
idx
,
node
in
enumerate
(
self
.
node_list
):
node_name
=
get_node_name
(
node
)
if
node
.
op
==
"placeholder"
:
if
node
.
op
==
"placeholder"
:
self
.
_assign_all_indice
(
node
,
idx
)
self
.
_assign_all_indice
(
node
,
idx
)
elif
node
.
op
==
"call_method"
:
elif
node
.
op
==
"call_method"
:
if
"transpose"
in
node
.
name
:
if
"transpose"
==
node
_
name
:
self
.
_assign_transpose_indice
(
node
,
idx
)
self
.
_assign_transpose_indice
(
node
,
idx
)
elif
"permute"
in
node
.
name
:
elif
"permute"
==
node
_
name
:
self
.
_assign_permute_indice
(
node
,
idx
)
self
.
_assign_permute_indice
(
node
,
idx
)
elif
"view"
in
node
.
name
or
"reshape"
in
node
.
name
:
elif
"view"
==
node
_
name
or
"reshape"
==
node
_
name
:
self
.
_assign_view_reshape_indice
(
node
,
idx
)
self
.
_assign_view_reshape_indice
(
node
,
idx
)
elif
"unsqueeze"
in
node
.
name
:
elif
"unsqueeze"
==
node
_
name
:
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
,
"clone"
]):
elif
"split"
==
node_name
:
self
.
_assign_split_indice
(
node
,
idx
)
elif
any
(
i
==
node_name
for
i
in
[
"to"
,
"contiguous"
,
"clone"
,
"type"
]):
self
.
_assgin_no_change_indice
(
node
,
idx
)
self
.
_assgin_no_change_indice
(
node
,
idx
)
elif
"new_ones"
in
node
.
name
:
elif
"new_ones"
==
node
_
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
self
.
_assign_ones_like_indice
(
node
,
idx
)
elif
any
(
i
==
node_name
for
i
in
[
"size"
]):
continue
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
raise
NotImplementedError
(
node
_
name
,
"method not implemented yet!"
)
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
if
"linear"
in
node
.
name
:
if
"linear"
==
node
_
name
:
self
.
_assign_linear_indice
(
node
,
idx
)
self
.
_assign_linear_indice
(
node
,
idx
)
elif
"cat"
in
node
.
name
:
elif
"cat"
==
node
_
name
:
self
.
_assign_cat_indice
(
node
,
idx
)
self
.
_assign_cat_indice
(
node
,
idx
)
elif
"matmul"
in
node
.
name
:
elif
"matmul"
==
node
_
name
:
self
.
_assign_matmul_indice
(
node
,
idx
)
self
.
_assign_matmul_indice
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
elif
"softmax"
==
node
_
name
:
self
.
_assign_softmax_indice
(
node
,
idx
)
self
.
_assign_softmax_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
,
"sub"
,
"truediv"
]):
elif
any
(
n
==
node_name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
,
"sub"
,
"truediv"
,
"pow"
,
"dropout"
,
"where"
,
"tanh"
,
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
self
.
_assign_elementwise_indice
(
node
,
idx
)
elif
"ones_like"
in
node
.
name
:
elif
"ones_like"
==
node
_
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
self
.
_assign_ones_like_indice
(
node
,
idx
)
elif
"dropout"
in
node
.
name
:
elif
"einsum"
==
node_name
:
self
.
_assign_dropout_indice
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_indice
(
node
,
idx
)
self
.
_assign_einsum_indice
(
node
,
idx
)
elif
"sum"
in
node
.
name
:
elif
"sum"
==
node
_
name
:
self
.
_assign_sum_indice
(
node
,
idx
)
self
.
_assign_sum_indice
(
node
,
idx
)
elif
"layer_norm"
in
node
.
name
:
elif
"layer_norm"
==
node
_
name
:
self
.
_assign_layernorm_indice
(
node
,
idx
)
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
"getitem"
in
node
.
name
:
elif
"getitem"
==
node
_
name
:
self
.
_assign_getitem_indice
(
node
,
idx
)
self
.
_assign_getitem_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"getattr"
,
"getitem"
,
"eq"
,
"_assert"
]):
elif
"addmm"
==
node_name
:
self
.
_assign_addmm_indice
(
node
,
idx
)
elif
"arange"
==
node_name
:
self
.
_assign_arange_indice
(
node
,
idx
)
elif
"tensor"
==
node_name
:
self
.
_assign_arange_indice
(
node
,
idx
)
elif
any
(
i
==
node_name
for
i
in
[
"getattr"
,
"eq"
,
"_assert_is_none"
,
"_assert"
,
"finfo"
]):
continue
continue
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"function not implemented yet!"
)
raise
NotImplementedError
(
node
_
name
,
"function not implemented yet!"
)
elif
node
.
op
==
"call_module"
:
elif
node
.
op
==
"call_module"
:
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
node_name
=
get_module_node_name
(
node
)
if
"layernorm"
==
node_name
:
self
.
_assign_layernorm_indice
(
node
,
idx
)
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
]):
elif
"embedding"
==
node_name
:
self
.
_assign_embedding_indice
(
node
,
idx
)
elif
any
(
n
==
node_name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
self
.
_assign_elementwise_indice
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
raise
NotImplementedError
(
node
_
name
,
"module not implemented yet!"
)
elif
node
.
op
==
"get_attr"
:
elif
node
.
op
==
"get_attr"
:
self
.
_assign_all_indice
(
node
,
idx
)
# get param
self
.
_assign_all_indice
(
node
,
idx
)
# get param
elif
node
.
op
==
"output"
:
elif
node
.
op
==
"output"
:
...
...
colossalai/autochunk/utils.py
View file @
63199c66
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
,
Union
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
NON_COMPUTE_OP
=
[
"placeholder"
,
"get_attr"
,
"output"
]
NON_COMPUTE_NAME
=
[
"getattr"
,
"eq"
,
"_assert_is_none"
,
"_assert"
,
"finfo"
,
"size"
]
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
def
get_logger
():
def
get_logger
()
->
Any
:
return
logger
return
logger
...
@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:
...
@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:
def
is_non_compute_node
(
node
:
Node
)
->
bool
:
def
is_non_compute_node
(
node
:
Node
)
->
bool
:
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
]
)
or
any
(
i
in
node
.
name
for
i
in
[
"getattr"
]
):
if
any
(
i
==
node
.
op
for
i
in
NON_COMPUTE_OP
)
or
any
(
i
==
get_
node
_
name
(
node
)
for
i
in
NON_COMPUTE_NAME
):
return
True
return
True
if
"getitem"
in
node
.
name
:
if
"getitem"
in
node
.
name
:
node_args
=
flat_list
(
node
.
args
[
1
:])
node_args
=
flat_list
(
node
.
args
[
1
:])
...
@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
...
@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
return
is_non_compute_node
(
node
)
return
is_non_compute_node
(
node
)
def
is_non_compute_node_except_placeholder
(
node
)
:
def
is_non_compute_node_except_placeholder
(
node
:
Node
)
->
bool
:
if
"placeholder"
in
node
.
op
:
if
"placeholder"
in
node
.
op
:
return
False
return
False
return
is_non_compute_node
(
node
)
return
is_non_compute_node
(
node
)
def
is_non_compute_node_except_placeholder_output
(
node
)
:
def
is_non_compute_node_except_placeholder_output
(
node
:
Node
)
->
bool
:
if
"output"
in
node
.
op
:
if
"output"
in
node
.
op
:
return
False
return
False
return
is_non_compute_node_except_placeholder
(
node
)
return
is_non_compute_node_except_placeholder
(
node
)
def
find_idx_by_name
(
name
,
nodes_list
)
:
def
find_idx_by_name
(
name
:
str
,
nodes_list
:
List
)
->
int
:
for
idx
,
node
in
enumerate
(
nodes_list
):
for
idx
,
node
in
enumerate
(
nodes_list
):
if
node
.
name
==
name
:
if
node
.
name
==
name
:
return
idx
return
idx
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
def
delete_free_var_from_last_use
(
user_to_last_uses
)
:
def
delete_free_var_from_last_use
(
user_to_last_uses
:
Dict
)
->
None
:
for
key
,
value
in
user_to_last_uses
.
items
():
for
key
,
value
in
user_to_last_uses
.
items
():
for
n
in
value
:
for
n
in
value
:
if
n
.
op
==
"placeholder"
:
if
n
.
op
==
"placeholder"
:
user_to_last_uses
[
key
].
remove
(
n
)
user_to_last_uses
[
key
].
remove
(
n
)
def
find_chunk_all_input_nodes
(
nodes
:
List
[
Node
]):
def
find_chunk_all_input_nodes
(
nodes
:
List
[
Node
])
->
List
:
"""
"""
Find non-compute input and output node names.
Find non-compute input and output node names.
input nodes are nodes used in the list
input nodes are nodes used in the list
...
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
...
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
return
input_nodes
return
input_nodes
def
find_chunk_compute_input_and_output_nodes
(
nodes
:
List
[
Node
]):
def
find_chunk_compute_input_and_output_nodes
(
nodes
:
List
[
Node
])
->
Union
[
List
,
List
]
:
"""
"""
Find non-compute input and output node names.
Find non-compute input and output node names.
input nodes are nodes used in the list
input nodes are nodes used in the list
...
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
...
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
output_nodes
.
append
(
node
)
output_nodes
.
append
(
node
)
return
input_nodes
,
output_nodes
return
input_nodes
,
output_nodes
def
get_module_node_name
(
node
:
Node
)
->
str
:
"""
get module class name
"""
node_targets
=
node
.
target
.
split
(
"."
)
module
=
node
.
graph
.
owning_module
for
i
in
node_targets
:
module
=
getattr
(
module
,
i
)
module_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
][:
-
2
]
module_name
=
module_name
.
lower
()
return
module_name
def
get_node_name
(
node
:
Node
)
->
str
:
"""
get node name
"""
node_name
=
node
.
name
if
"_"
in
node_name
:
for
i
in
range
(
len
(
node_name
)
-
1
,
-
1
,
-
1
):
if
node_name
[
i
]
==
"_"
:
node_name
=
node_name
[:
i
]
break
elif
node_name
[
i
]
in
[
"1"
,
"2"
,
"3"
,
"4"
,
"5"
,
"6"
,
"7"
,
"8"
,
"9"
,
"0"
]:
continue
else
:
break
return
node_name
tests/test_autochunk/benchmark_simple_evoformer.py
deleted
100644 → 0
View file @
6e0faa70
import
time
import
torch
import
torch.fx
from
simple_evoformer
import
base_evoformer
,
openfold_evoformer
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
MetaTensor
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
    """Time `loop` forward passes of `model` on (node, pair) and print the mean
    latency plus peak CUDA memory (MB) above the pre-run allocation.

    If `chunk_size` is truthy it is forwarded as a third positional argument
    (used by the openfold-style manually-chunked model).
    """
    torch.cuda.reset_peak_memory_stats()
    # baseline allocation before the runs, in MB
    now_mem = torch.cuda.memory_allocated() / 1024**2
    loop = 3
    with torch.no_grad():
        # warmup passes (loop // 2 + 1) so the timed loop excludes one-time costs
        for _ in range(loop // 2 + 1):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        # synchronize so CPU timestamps bracket actual GPU work
        torch.cuda.synchronize()
        time1 = time.time()
        for _ in range(loop):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        torch.cuda.synchronize()
        time2 = time.time()
    # peak memory attributable to the forward passes, in MB
    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))
def _build_autochunk(model, max_memory, node, pair):
    """Trace `model`, run meta-info propagation, and recompile it with
    AutoChunkCodeGen under the given `max_memory` budget (MB, or None for
    minimum-memory mode). Returns the chunked ColoGraphModule.
    """
    # trace the module and replace codegen
    graph = ColoTracer().trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )

    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # now run it twice to get meta info in graph module, not necessary
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # set code_gen
    codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # print
    # code = graph.python_code("self").src
    # print(code)
    return gm
def benchmark_evoformer():
    """Benchmark the baseline evoformer, openfold's manually-chunked variant,
    and the autochunk-generated module on the same random CUDA inputs."""
    # init data and model
    msa_len = 128
    pair_len = 256
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    model = base_evoformer().cuda()

    # build autochunk model
    # max_memory = 1000  # MB, fit memory mode
    max_memory = None    # min memory mode
    autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)

    # build openfold
    chunk_size = 64
    openfold = openfold_evoformer().cuda()

    # benchmark
    _benchmark_evoformer(model, node, pair, "base")
    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
    _benchmark_evoformer(autochunk, node, pair, "autochunk")


if __name__ == "__main__":
    benchmark_evoformer()
tests/test_autochunk/test_alphafold/test_alphafold_utils.py
0 → 100644
View file @
63199c66
from
typing
import
Any
,
Dict
,
List
import
torch
import
torch.fx
import
colossalai
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
from
colossalai.autochunk.utils
import
flat_list
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
AUTOCHUNK_AVAILABLE
:
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def assert_codegen_run(
    model: Any,
    meta_args: List,
    concrete_args: List = None,
    max_memory: int = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    """Trace `model` with AutoChunkCodeGen, check chunked code was emitted, and
    verify the chunked module's outputs match the original model's.

    Args:
        model: the module under test (moved to CUDA before the forward check).
        meta_args: list of (name, tensor) pairs traced as meta tensors.
        concrete_args: list of (name, value) pairs traced as concrete values.
        max_memory: chunk-search memory budget in MB; None means minimum-memory mode.
        print_mem / print_progress / print_code: debug printing toggles.

    Returns:
        The codegen's chunk_infos list (one dict per chunk region).
    """
    if concrete_args is None:
        concrete_args = []

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    model.cuda()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    # flatten nested outputs so tensors can be compared pairwise
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
    for out_gm_i, out_model_i in zip(out_gm, out_model):
        assert torch.allclose(out_gm_i, out_model_i,
                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                                  torch.abs(out_gm_i - out_model_i))

    return chunks
def run_test(
    rank: int,
    data_args: tuple,
    max_memory: int,
    get_model: Any,
    get_data: Any,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    """Worker entry point (run via mp.spawn): launch a 1-process colossalai
    context, build model/data via the supplied factories, run the codegen
    check, and optionally compare found chunk regions against a target table
    keyed by `max_memory`.
    """
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = get_model()
    meta_args, concrete_args = get_data(*data_args)

    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        concrete_args=concrete_args,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        # compare the discovered chunk regions with the expected ones
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )
tests/test_autochunk/test_alphafold/test_evoformer_block.py
0 → 100644
View file @
63199c66
from
functools
import
partial
from
typing
import
Dict
,
List
,
Tuple
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
EvoformerBlock
HAS_REPO
=
True
except
:
HAS_REPO
=
False
from
test_alphafold_utils
import
run_test
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
def get_model():
    """Build a single fastfold EvoformerBlock in eval mode on CUDA."""
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    """Build random CUDA inputs for EvoformerBlock.

    Returns (meta_args, concrete_args) as lists of (arg_name, value) pairs
    matching EvoformerBlock.forward's signature.
    """
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_mask_trans", True)]
    return meta_args, concrete_args
def get_chunk_target() -> Dict:
    """Expected chunk regions, keyed by max_memory budget (None = min-memory mode).

    Each value lists (start, end) node-index pairs of the regions autochunk
    is expected to select for EvoformerBlock.
    """
    min_memory_regions = [
        (118, 123),
        (219, 237),
        (264, 289),
        (302, 309),
        (97, 104),
        (144, 152),
        (185, 193),
        (241, 242),
        (21, 46),
    ]
    budget_20_regions = [
        (118, 123),
        (230, 237),
        (275, 282),
        (305, 306),
        (100, 101),
        (32, 39),
        (73, 79),
    ]
    budget_24_regions = [(118, 123)]
    return {None: min_memory_regions, 20: budget_20_regions, 24: budget_24_regions}
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_block(data_args, max_memory):
    """Run the EvoformerBlock autochunk check in a spawned single process,
    verifying both output correctness and the selected chunk regions."""
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Direct invocation (no pytest / no mp.spawn) for ad-hoc debugging.
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_alphafold/test_evoformer_stack.py
0 → 100644
View file @
63199c66
from
functools
import
partial
from
typing
import
List
,
Tuple
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
EvoformerStack
HAS_REPO
=
True
except
:
HAS_REPO
=
False
from
test_alphafold_utils
import
run_test
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
def get_model():
    """Build a small (2-block) fastfold EvoformerStack in eval mode on CUDA."""
    model = EvoformerStack(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        c_s=384,
        no_heads_msa=8,
        no_heads_pair=4,
        no_blocks=2,    # 48
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=None,
        inf=1000000000.0,
        eps=1e-08,
        clear_cache_between_blocks=False,
        is_multimer=False,
    ).eval().cuda()
    return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    """Build random CUDA inputs for EvoformerStack.

    Returns (meta_args, concrete_args) as lists of (arg_name, value) pairs
    matching EvoformerStack.forward's signature.
    """
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_mask_trans", True)]
    return meta_args, concrete_args
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
    """Run the EvoformerStack autochunk check in a spawned single process.
    No chunk-region target table here — only output correctness is checked."""
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Direct invocation (no pytest / no mp.spawn) for ad-hoc debugging.
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_alphafold/test_extramsa_block.py
0 → 100644
View file @
63199c66
from
functools
import
partial
from
typing
import
Dict
,
List
,
Tuple
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
ExtraMSABlock
HAS_REPO
=
True
except
:
HAS_REPO
=
False
from
test_alphafold_utils
import
run_test
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
def get_model():
    """Build a single fastfold ExtraMSABlock in eval mode on CUDA."""
    model = ExtraMSABlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    ).eval().cuda()
    return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    """Build random CUDA inputs for ExtraMSABlock.

    Returns (meta_args, concrete_args) as lists of (arg_name, value) pairs;
    note the concrete `_chunk_logits` argument specific to ExtraMSABlock.
    """
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
    return meta_args, concrete_args
def get_chunk_target() -> Dict:
    """Expected chunk regions, keyed by max_memory budget (None = min-memory mode).

    Each value lists (start, end) node-index pairs of the regions autochunk
    is expected to select for ExtraMSABlock.
    """
    min_memory_regions = [
        (126, 131),
        (227, 245),
        (272, 297),
        (310, 317),
        (105, 112),
        (152, 160),
        (193, 201),
        (249, 250),
        (33, 46),
    ]
    budget_20_regions = [
        (126, 131),
        (238, 245),
        (283, 290),
        (313, 314),
        (108, 109),
        (35, 46),
    ]
    budget_24_regions = [(126, 131)]
    return {None: min_memory_regions, 20: budget_20_regions, 24: budget_24_regions}
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
    """Run the ExtraMSABlock autochunk check in a spawned single process.

    NOTE(review): unlike the __main__ path below, this does not pass
    get_chunk_target, so only output correctness is verified here.
    """
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Direct invocation (no pytest / no mp.spawn) for ad-hoc debugging;
    # also checks the chunk regions against get_chunk_target().
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_
simple_evoformer_codegen
.py
→
tests/test_autochunk/test_
diffuser/test_diffuser_utils
.py
View file @
63199c66
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
import
pytest
import
torch
import
torch
import
torch.fx
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
simple_evoformer
import
base_evoformer
HAS_REPO
=
True
except
:
HAS_REPO
=
False
import
colossalai
import
colossalai
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx
import
ColoTracer
,
symbolic_trace
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
if
CODEGEN
_AVAILABLE
and
is_compatible_with_meta
()
:
if
AUTOCHUNK
_AVAILABLE
:
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def
assert_codegen_run
(
model
:
Any
,
meta_args
:
List
,
concrete_args
:
List
=
None
,
max_memory
:
int
=
None
,
print_mem
:
bool
=
False
,
print_progress
:
bool
=
False
,
print_code
:
bool
=
False
,
)
->
List
[
Dict
]:
if
concrete_args
is
None
:
concrete_args
=
[]
model
=
model
()
# trace the meta graph and setup codegen
meta_graph
=
symbolic_trace
(
model
,
meta_args
=
{
k
:
v
.
to
(
torch
.
device
(
"meta"
))
for
k
,
v
in
meta_args
},
concrete_args
=
{
k
:
v
for
k
,
v
in
concrete_args
},
)
interp
=
MetaInfoProp
(
meta_graph
)
meta_tensors
=
[
MetaTensor
(
i
[
1
],
fake_device
=
"cuda:0"
)
for
i
in
meta_args
]
+
[
i
[
1
]
for
i
in
concrete_args
]
interp
.
propagate
(
*
meta_tensors
)
codegen
=
AutoChunkCodeGen
(
meta_graph
,
max_memory
=
max_memory
,
print_mem
=
print_mem
,
print_progress
=
print_progress
,
)
chunks
=
codegen
.
chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph
=
ColoTracer
().
trace
(
model
.
cuda
(),
meta_args
=
{
k
:
v
.
to
(
torch
.
device
(
"meta"
))
for
k
,
v
in
meta_args
},
concrete_args
=
{
k
:
v
for
k
,
v
in
concrete_args
},
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
,
ckpt_codegen
=
False
)
gm
.
recompile
()
def
_test_fwd
(
model
:
torch
.
nn
.
Module
,
gm
:
ColoGraphModule
,
node
,
pair
):
# assert chunk in code
with
torch
.
no_grad
():
code
=
graph
.
python_code
(
"self"
).
src
non_fx_out
=
model
(
node
,
pair
)
if
print_code
:
fx_out
=
gm
(
node
,
pair
)
print
(
code
)
assert
"chunk_result = None; chunk_size = None;"
in
code
assert
torch
.
allclose
(
non_fx_out
[
0
],
fx_out
[
0
],
# assert result
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
inputs
=
[
i
[
1
]
for
i
in
meta_args
]
+
[
i
[
1
]
for
i
in
concrete_args
]
torch
.
abs
(
non_fx_out
[
0
]
-
fx_out
[
0
]))
model
.
cuda
().
eval
()
assert
torch
.
allclose
(
non_fx_out
[
1
],
fx_out
[
1
],
gm
.
eval
()
with
torch
.
no_grad
():
out_gm
=
gm
(
*
inputs
)
out_model
=
model
(
*
inputs
)
assert
torch
.
allclose
(
out_gm
[
"sample"
],
out_model
[
"sample"
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
1
]
-
fx_out
[
1
]))
torch
.
abs
(
out_gm
[
"sample"
]
-
out_model
[
"sample"
]))
return
chunks
def
_test_simple_evoformer_codegen
(
rank
,
msa_len
,
pair_len
,
max_memory
):
def
run_test
(
rank
:
int
,
model
:
Any
,
data
:
tuple
,
max_memory
:
int
,
print_code
:
bool
,
print_mem
:
bool
,
print_progress
:
bool
,
get_chunk_target
:
Any
=
None
,
)
->
None
:
# launch colossalai
# launch colossalai
colossalai
.
launch
(
colossalai
.
launch
(
config
=
{},
config
=
{},
...
@@ -50,55 +98,23 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
...
@@ -50,55 +98,23 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
)
)
# build model and input
# build model and input
model
=
base_evoformer
().
cuda
()
meta_args
,
concrete_args
=
data
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
chunks
=
assert_codegen_run
(
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
# meta info prop
meta_graph
=
symbolic_trace
(
model
,
meta_args
=
{
"node"
:
node
.
to
(
torch
.
device
(
"meta"
)),
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
})
# must use symbolic_trace
interp
=
MetaInfoProp
(
meta_graph
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
codegen
=
AutoChunkCodeGen
(
meta_graph
,
max_memory
=
max_memory
)
# trace the module and replace codegen
graph
=
ColoTracer
().
trace
(
model
,
model
,
meta_args
=
{
meta_args
=
meta_args
,
"node"
:
node
.
to
(
torch
.
device
(
"meta"
)),
concrete_args
=
concrete_args
,
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
},
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
,
ckpt_codegen
=
False
)
gm
.
recompile
()
# assert we have inserted chunk
code
=
graph
.
python_code
(
"self"
).
src
# print(code)
assert
"chunk_result = None; chunk_size = None;"
in
code
_test_fwd
(
model
,
gm
,
node
,
pair
)
gpc
.
destroy
()
@
pytest
.
mark
.
skipif
(
not
(
CODEGEN_AVAILABLE
and
is_compatible_with_meta
()
and
HAS_REPO
),
reason
=
'torch version is lower than 1.12.0'
)
@
pytest
.
mark
.
parametrize
(
"max_memory"
,
[
None
,
20
,
25
,
30
])
@
pytest
.
mark
.
parametrize
(
"msa_len"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"pair_len"
,
[
64
])
def
test_simple_evoformer_codegen
(
msa_len
,
pair_len
,
max_memory
):
run_func
=
partial
(
_test_simple_evoformer_codegen
,
msa_len
=
msa_len
,
pair_len
=
pair_len
,
max_memory
=
max_memory
,
max_memory
=
max_memory
,
print_code
=
print_code
,
print_mem
=
print_mem
,
print_progress
=
print_progress
,
)
)
mp
.
spawn
(
run_func
,
nprocs
=
1
)
if
get_chunk_target
is
not
None
:
chunk_found
=
[
i
[
"region"
]
for
i
in
chunks
]
chunk_target
=
get_chunk_target
()[
max_memory
]
assert
(
chunk_found
==
chunk_target
),
"found regions %s doesn't equal target regions %s"
%
(
str
(
chunk_found
),
str
(
chunk_target
),
)
if
__name__
==
"__main__"
:
gpc
.
destroy
()
_test_simple_evoformer_codegen
(
0
,
32
,
64
,
25
)
tests/test_autochunk/test_diffuser/test_unet.py
0 → 100644
View file @
63199c66
from
functools
import
partial
from
typing
import
List
,
Tuple
import
pytest
import
torch
import
torch.multiprocessing
as
mp
try
:
from
diffusers
import
UNet2DModel
MODELS
=
[
UNet2DModel
]
HAS_REPO
=
True
except
:
MODELS
=
[]
HAS_REPO
=
False
from
test_diffuser_utils
import
run_test
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
BATCH_SIZE
=
2
SEQ_LENGTH
=
5
HEIGHT
=
224
WIDTH
=
224
IN_CHANNELS
=
3
LATENTS_SHAPE
=
(
BATCH_SIZE
,
IN_CHANNELS
,
HEIGHT
//
7
,
WIDTH
//
7
)
def get_data(shape: tuple) -> Tuple[List, List]:
    """Build (meta_args, concrete_args) for tracing a diffusers UNet2DModel.

    The latent sample is random with the given shape; the timestep is a
    fixed concrete integer.
    """
    meta_args = [("sample", torch.randn(shape))]
    concrete_args = [("timestep", 50)]
    return meta_args, concrete_args
@pytest.mark.skipif(
    True,
    reason="not implemented",
)
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
def test_evoformer_block(model, shape, max_memory):
    """Run the UNet autochunk check in a spawned single process.

    Currently skipped unconditionally ("not implemented").
    NOTE(review): the name looks copy-pasted from the alphafold test —
    this exercises a diffusers UNet, not an evoformer block.
    """
    run_func = partial(
        run_test,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Direct invocation (no pytest / no mp.spawn) for ad-hoc debugging.
    run_test(
        rank=0,
        data=get_data(LATENTS_SHAPE),
        max_memory=64,
        model=UNet2DModel,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_evoformer_codegen.py
deleted
100644 → 0
View file @
6e0faa70
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
EvoformerBlock
HAS_REPO
=
True
except
:
HAS_REPO
=
False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    """Forward both the original model and the chunked graph module on the
    same inputs and assert the two outputs (msa, pair) match within atol=1e-4."""
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
    """Build a single fastfold EvoformerBlock in eval mode on CUDA."""
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    """Worker (run via mp.spawn): trace EvoformerBlock with AutoChunkCodeGen,
    assert chunked code was generated, and check forward-pass equivalence."""
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    """Spawn a single worker process per parameter combination."""
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Direct invocation (no pytest / no mp.spawn) for ad-hoc debugging.
    _test_evoformer_codegen(0, 32, 64, 24)
tests/test_autochunk/test_evoformer_stack_codegen.py
deleted
100644 → 0
View file @
6e0faa70
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
EvoformerStack
HAS_REPO
=
True
except
:
HAS_REPO
=
False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    """Forward both the original EvoformerStack and the chunked graph module
    (extra trailing None arg) and assert the two outputs match within atol=1e-4."""
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1, None)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask, None)
        fx_out = gm(node, pair, node_mask, pair_mask, None)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
    """Build a small (2-block) fastfold EvoformerStack in eval mode on CUDA."""
    model = EvoformerStack(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        c_s=384,
        no_heads_msa=8,
        no_heads_pair=4,
        no_blocks=2,    # 48
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=None,
        inf=1000000000.0,
        eps=1e-08,
        clear_cache_between_blocks=False,
        is_multimer=False,
    ).eval().cuda()
    return model
def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
    """Worker (run via mp.spawn): trace EvoformerStack with AutoChunkCodeGen,
    assert chunked code was generated, and check forward-pass equivalence."""
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    # trailing None matches the stack's extra forward argument
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
        None)
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
    """Spawn a single worker process per parameter combination."""
    run_func = partial(
        _test_evoformer_stack_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Manual entry point: run the worker directly (rank 0, msa_len=32,
    # pair_len=64, no memory cap) without pytest / mp.spawn for quick debugging.
    _test_evoformer_stack_codegen(0, 32, 64, None)
tests/test_autochunk/test_extramsa_codegen.py
deleted
100644 → 0
View file @
6e0faa70
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try:
    from fastfold.model.nn.evoformer import ExtraMSABlock
    HAS_REPO = True
except Exception:
    # Keep a broad catch (fastfold can fail to import for reasons beyond a
    # plain ImportError, e.g. broken CUDA extensions), but use
    # ``except Exception`` instead of a bare ``except:`` so KeyboardInterrupt
    # and SystemExit still propagate.
    HAS_REPO = False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    """Run the original model and the chunked GraphModule and compare outputs.

    Both modules are evaluated under ``torch.no_grad`` on the same inputs
    (assumed to already be CUDA tensors — see callers); the first two outputs
    are compared elementwise with an absolute tolerance of 1e-4.

    Raises:
        AssertionError: if any compared output pair differs beyond the
            tolerance; the message reports the mean absolute difference.
    """
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)
    # Deduplicated: compare both output tensors in one loop instead of two
    # copy-pasted asserts.  The diff is computed inside the assert message so
    # it is only evaluated on failure, as in the original code.
    for idx in range(2):
        assert torch.allclose(non_fx_out[idx], fx_out[idx], atol=1e-4), \
            "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                torch.abs(non_fx_out[idx] - fx_out[idx]))
def _build_openfold():
    """Construct a fastfold ExtraMSABlock in eval mode on the default GPU."""
    block_kwargs = dict(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    )
    block = ExtraMSABlock(**block_kwargs)
    return block.eval().cuda()
def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
    """Worker: trace an ExtraMSABlock, apply AutoChunk codegen, and verify it.

    Args:
        rank: process rank passed to ``colossalai.launch`` (world_size is 1).
        msa_len: MSA dimension of the random test inputs.
        pair_len: pair/residue dimension of the random test inputs.
        max_memory: memory budget handed to ``AutoChunkCodeGen``; ``None``
            lets the codegen choose freely.
    """
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_chunk_logits": 1024,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_chunk_logits": 1024,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code
    # chunked module must reproduce the original forward outputs
    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
    """Launch the ExtraMSA codegen check in a single spawned worker process."""
    worker = partial(
        _test_extramsa_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(worker, nprocs=1)
if __name__ == "__main__":
    # Manual entry point: run the worker directly (rank 0, msa_len=32,
    # pair_len=64, no memory cap) without pytest / mp.spawn.
    _test_extramsa_codegen(0, 32, 64, None)
tests/test_autochunk/test_simple_evoformer_search.py
deleted
100644 → 0
View file @
6e0faa70
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try:
    from simple_evoformer import base_evoformer
    HAS_REPO = True
except Exception:
    # ``except Exception`` instead of a bare ``except:`` so Ctrl-C and
    # SystemExit still propagate while any import failure just disables
    # the dependent tests.
    HAS_REPO = False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx
import
symbolic_trace
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
    """Check the chunk regions found by the search against recorded targets.

    Only the (msa_len=32, pair_len=64) configuration with the memory budgets
    listed below has known-good targets; anything else raises
    NotImplementedError.  A mismatch raises AssertionError.
    """
    found_regions = [info["region"] for info in chunk_infos]
    if not (msa_len == 32 and pair_len == 64):
        raise NotImplementedError()
    # expected regions keyed by memory budget (None = unconstrained)
    targets_by_budget = {
        None: [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191), (161, 166),
               (198, 203), (7, 57)],
        20: [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)],
        25: [(144, 154), (369, 370)],
        30: [(144, 154)],
    }
    if max_memory not in targets_by_budget:
        raise NotImplementedError()
    target_regions = targets_by_budget[max_memory]
    assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
        str(found_regions),
        str(target_regions),
    )
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
    """Worker: run the AutoChunk region search on a base evoformer and
    compare the found regions to the recorded targets."""
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    # build model and random inputs
    model = base_evoformer().cuda()
    msa_feat = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair_feat = torch.randn(1, pair_len, pair_len, 128).cuda()
    # must use symbolic_trace
    meta_graph = symbolic_trace(model,
                                meta_args={
                                    "node": msa_feat.to(torch.device("meta")),
                                    "pair": pair_feat.to(torch.device("meta")),
                                })
    prop = MetaInfoProp(meta_graph)
    prop.propagate(MetaTensor(msa_feat, fake_device="cuda:0"), MetaTensor(pair_feat, fake_device="cuda:0"))
    # run the chunk-region search and validate what it found
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    assert_chunk_infos(codegen.chunk_infos, max_memory, msa_len, pair_len)
    gpc.destroy()
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
                    reason="torch version is lower than 1.12.0")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
    """Spawn one worker that checks the chunk-region search results."""
    worker = partial(
        _test_simple_evoformer_search,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(worker, nprocs=1)
if __name__ == "__main__":
    # Manual entry point: run the worker directly (rank 0, msa_len=32,
    # pair_len=64, max_memory=20) without pytest / mp.spawn.
    _test_simple_evoformer_search(0, 32, 64, 20)
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
0 → 100644
View file @
63199c66
from
functools
import
partial
from
typing
import
List
,
Tuple
import
pytest
import
torch
import
torch.multiprocessing
as
mp
try:
    from transformers import GPT2Config, GPT2Model
    MODELS = [GPT2Model]
    HAS_REPO = True
except Exception:
    # Broad (but not bare) catch: any failure importing transformers simply
    # disables the GPT tests instead of erroring at collection time, while
    # KeyboardInterrupt/SystemExit still propagate.
    MODELS = []
    HAS_REPO = False
from
test_transformer_utils
import
run_test
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
BATCH_SIZE
=
2
SEQ_LENGTH
=
256
def get_data(shape: tuple) -> Tuple[dict, dict, list]:
    """Build tracer arguments for a GPT2 forward pass of the given shape.

    Args:
        shape: ``(batch_size, seq_length)`` of the token tensors.

    Returns:
        A 3-tuple of:
        - meta_args: tensor kwargs (zeroed ids/token types, all-ones mask)
          to be moved to the meta device by the caller,
        - concrete_args: kwargs traced as constants (``past_key_values=None``),
        - sequence: the positional argument order for the model's forward call.

    Note: the original annotation ``Tuple[List, List]`` was wrong — the
    function returns two dicts and a list.
    """
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    attention_mask = torch.ones(shape, dtype=torch.int64)
    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    concrete_args = {"past_key_values": None}
    sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
    return meta_args, concrete_args, sequence
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
    """Run the GPT2 autochunk codegen test in one spawned worker process."""
    gpt_config = GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4)
    worker = partial(
        run_test,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=gpt_config,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(worker, nprocs=1)
if __name__ == "__main__":
    # Direct single-process entry point for manual debugging: verbose output
    # (generated code and memory stats printed), no memory cap.
    run_test(
        rank=0,
        data=get_data((BATCH_SIZE, SEQ_LENGTH)),
        max_memory=None,
        model=GPT2Model,
        config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
        print_code=True,
        print_mem=True,
        print_progress=False,
    )
tests/test_autochunk/test_transformer/test_transformer_utils.py
0 → 100644
View file @
63199c66
from
typing
import
Any
,
Dict
,
List
import
torch
import
torch.fx
import
colossalai
from
colossalai.autochunk.autochunk_codegen
import
AUTOCHUNK_AVAILABLE
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
AUTOCHUNK_AVAILABLE
:
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def assert_codegen_run(
    model: Any,
    data: tuple,
    max_memory: int = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    """Trace ``model``, apply AutoChunk codegen, recompile, and verify outputs.

    Args:
        model: the (already constructed) module under test.
        data: ``(meta_args, concrete_args, sequence)`` as produced by the
            per-model ``get_data`` helpers; ``sequence`` gives the positional
            argument order for the forward call.
        max_memory: memory budget forwarded to ``AutoChunkCodeGen``.
        print_mem, print_progress, print_code: verbosity flags.

    Returns:
        The ``chunk_infos`` list computed by ``AutoChunkCodeGen``.

    Raises:
        AssertionError: if no chunk was inserted into the generated code, or
            if any tensor output of the chunked module differs from the
            original model's output.
    """
    meta_args, concrete_args, sequence = data
    if concrete_args is None:
        concrete_args = {}
    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    interp = MetaInfoProp(meta_graph)
    # assemble forward arguments in call order; tensors become MetaTensors
    meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos
    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code
    # assert result: the chunked module must reproduce the original outputs
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    for k in out_model.keys():
        if torch.is_tensor(out_gm[k]):
            assert torch.equal(
                out_model[k], out_gm[k]
            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
    return chunks
def run_test(
    rank: int,
    model: Any,
    config: Any,
    data: tuple,
    max_memory: int,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    """Instantiate ``model(config=config)``, launch colossalai, and run the
    codegen assertion.

    When ``get_chunk_target`` is provided, the chunk regions chosen by
    AutoChunk are additionally compared against
    ``get_chunk_target()[max_memory]``.
    """
    model = model(config=config)

    # launch colossalai (single process)
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # trace, apply codegen, recompile and execute the model
    chunks = assert_codegen_run(
        model,
        data=data,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    # optionally pin the exact regions AutoChunk is expected to pick
    if get_chunk_target is not None:
        chunk_found = [info["region"] for info in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )

    gpc.destroy()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment