OpenDAS / ColossalAI / Commits

Commit 63199c66 (unverified)
[autochunk] support transformer (#2526)

Authored Jan 31, 2023 by oahzxl; committed by GitHub on Jan 31, 2023.
Parent: 6e0faa70

Showing 19 changed files with 1198 additions and 964 deletions (+1198 −964).
colossalai/autochunk/autochunk_codegen.py                          +5    −2
colossalai/autochunk/search_chunk.py                               +14   −52
colossalai/autochunk/select_chunk.py                               +48   −76
colossalai/autochunk/trace_flow.py                                 +68   −44
colossalai/autochunk/trace_indice.py                               +241  −100
colossalai/autochunk/utils.py                                      +41   −9
tests/test_autochunk/benchmark_simple_evoformer.py                 +0    −94
tests/test_autochunk/test_alphafold/test_alphafold_utils.py        +122  −0
tests/test_autochunk/test_alphafold/test_evoformer_block.py        +95   −0
tests/test_autochunk/test_alphafold/test_evoformer_stack.py        +90   −0
tests/test_autochunk/test_alphafold/test_extramsa_block.py         +96   −0
tests/test_autochunk/test_diffuser/test_diffuser_utils.py          +120  −0
tests/test_autochunk/test_diffuser/test_unet.py                    +70   −0
tests/test_autochunk/test_evoformer_codegen.py                     +0    −163
tests/test_autochunk/test_evoformer_stack_codegen.py               +0    −163
tests/test_autochunk/test_extramsa_codegen.py                      +0    −164
tests/test_autochunk/test_simple_evoformer_search.py               +0    −97
tests/test_autochunk/test_transformer/test_autochunk_gpt.py        +65   −0
tests/test_autochunk/test_transformer/test_transformer_utils.py    +123  −0
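Before the per-file diffs, it helps to see how the pieces fit together. The sketch below follows the wiring used by the new test helpers in this commit (see assert_codegen_run in test_alphafold_utils.py): trace on meta tensors, propagate shape info, build the AutoChunkCodeGen, then re-trace and recompile with the codegen attached. The toy build function and the "x" argument name are illustrative assumptions, not part of the diff.

import torch

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def build_autochunk_module(model: torch.nn.Module, sample: torch.Tensor, max_memory: int = None):
    # 1) trace on meta tensors and propagate shape/memory info
    meta_graph = symbolic_trace(model, meta_args={"x": sample.to(torch.device("meta"))})
    MetaInfoProp(meta_graph).propagate(MetaTensor(sample, fake_device="cuda:0"))
    # 2) search chunk regions and build the chunked codegen
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    # 3) re-trace with ColoTracer, attach the codegen, and recompile
    graph = ColoTracer().trace(model, meta_args={"x": sample.to(torch.device("meta"))})
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    return gm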
colossalai/autochunk/autochunk_codegen.py

...
@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
 import torch

 import colossalai
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

-if CODEGEN_AVAILABLE:
+AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
+
+if AUTOCHUNK_AVAILABLE:
     from torch.fx.graph import (
         CodeGen,
         PythonCode,
...
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
         node_idx += 1

-if CODEGEN_AVAILABLE:
+if AUTOCHUNK_AVAILABLE:

     class AutoChunkCodeGen(CodeGen):
...
colossalai/autochunk/search_chunk.py

...
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import (
+    find_chunk_compute_input_and_output_nodes,
+    get_logger,
+    get_node_shape,
+    is_non_compute_node,
+    is_non_compute_node_except_placeholder,
+)


 class SearchChunk(object):
...
@@ -114,6 +120,12 @@ class SearchChunk(object):
             chunk_region_start (int)
             chunk_region_end (int)
         """
+        # check if the peak node is already inside a found chunk region
+        if chunk_regions is not None:
+            for i in chunk_regions:
+                if i["region"][0] < peak_node_idx <= i["region"][1]:
+                    return None
+
         free_vars = self._get_free_var_idx()
         free_var_num = len(free_vars)
         active_node_num = [len(i) for i in active_node]
...
@@ -152,55 +164,6 @@ class SearchChunk(object):
             chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

-    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
-        """
-        Find chunk info for a region.
-
-        We are given the region start and region end, and need to find out all chunk info for it.
-        We first loop every dim of the start node and end node, to see if we can find a dim pair
-        which is linked in a flow and not computed.
-        If found, we then search the flow in the whole region to find out all chunk infos.
-
-        Args:
-            input_trace (List): node's input trace in region
-            output_trace (List): node's output trace in region
-            start_idx (int): region start node index
-            end_idx (int): region end node index
-
-        Returns:
-            chunk_infos: possible regions found
-        """
-        start_traces = input_trace[start_idx]
-        end_trace = output_trace[end_idx]
-        end_node = self.trace_indice.node_list[end_idx]
-        chunk_infos = []
-        for end_dim, _ in enumerate(end_trace["indice"]):
-            if len(start_traces) > 1:
-                continue
-            for start_node, start_trace in start_traces.items():
-                for start_dim, _ in enumerate(start_trace["indice"]):
-                    # dim size cannot be 1
-                    if (get_node_shape(end_node)[end_dim] == 1
-                            or get_node_shape(start_node)[start_dim] == 1):
-                        continue
-                    # must have users
-                    if len(end_node.users) == 0:
-                        continue
-                    # check index source align
-                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
-                        continue
-                    # check index compute
-                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
-                        continue
-                    # flow search
-                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
-                    if chunk_info is None:
-                        continue
-                    # check index duplicate
-                    if not self.trace_flow.check_index_duplicate(chunk_info):
-                        continue
-                    chunk_infos.append(chunk_info)
-        return chunk_infos
-
     def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
...
@@ -228,9 +191,8 @@ class SearchChunk(object):
             if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
                     self.trace_indice.node_list[end_idx]):
                 continue
             # select free dim
-            chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+            chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
             if len(chunk_info) > 0:
                 possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
...
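The new guard in _search_chunk_region skips a peak node that already falls inside a previously selected chunk region; note the half-open containment test region[0] < peak_node_idx <= region[1]. A standalone sketch of the same predicate (the function name is illustrative, not part of the diff):

from typing import Dict, List, Optional


def peak_already_chunked(peak_node_idx: int, chunk_regions: Optional[List[Dict]]) -> bool:
    """True when the peak node index falls inside an existing chunk region.

    Mirrors the half-open containment test added to SearchChunk:
    region[0] < peak_node_idx <= region[1].
    """
    if chunk_regions is None:
        return False
    return any(r["region"][0] < peak_node_idx <= r["region"][1] for r in chunk_regions)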
colossalai/autochunk/select_chunk.py

...
@@ -5,6 +5,7 @@ from .utils import is_non_compute_node


 class SelectChunk(object):
+
     def __init__(
         self,
         trace_indice: TraceIndice,
...
@@ -21,9 +22,7 @@ class SelectChunk(object):
         else:
             self.stratge = "min_memory"

-    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
-                                  mem_peak):
+    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
         if self.stratge == "min_memory":
             best_region = self._select_min_memory_chunk_region(
                 possible_chunk_regions,
...
@@ -44,9 +43,8 @@ class SelectChunk(object):
                 raise RuntimeError()
         return best_region

-    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
-                                        mem_peak):
+    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
         # stop chunk if max memory satisfies the memory limit
         if max(mem_peak) < self.max_memory:
             return None
...
@@ -63,33 +61,26 @@ class SelectChunk(object):
         if len(possible_chunk_regions) == 0:
             return None

+        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+                                     max([i["region"][1] for i in possible_chunk_regions]))
+
         # get mem for chunk region
         regions_dict = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list,
-                                                                       cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list,
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]:max_chunk_region[1] + 1]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
             if cur_chunk_region_max_peak < self.max_memory:
-                regions_dict.append(
-                    {
-                        "chunk_info": region,
-                        "chunk_max_mem": cur_chunk_region_max_peak,
-                        "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
-                        "reorder_chunk_info": cur_region,
-                        "reorder_node_list": cur_node_list,
-                    }
-                )
+                regions_dict.append({
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                    "reorder_chunk_info": cur_region,
+                    "reorder_node_list": cur_node_list,
+                })

         # no region found
         if len(regions_dict) == 0:
             raise RuntimeError("Search failed. Try a larger memory threshold.")
...
@@ -113,20 +104,13 @@ class SelectChunk(object):
                 chunk_size *= 2
             reorder_chunk_info["chunk_size"] = chunk_size
             cur_chunk_infos = chunk_infos + [reorder_chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])

         # search exact size
         chunk_info = chunk_region_dict["chunk_info"]
-        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size,
-                                                                  chunk_region_dict, chunk_infos)
+        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos)
         return chunk_info

     def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
...
@@ -139,12 +123,9 @@ class SelectChunk(object):
             mid = int((left + right) / 2 + 0.5)
             chunk_info["chunk_size"] = mid
             cur_chunk_infos = chunk_infos + [chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
             if cur_chunk_max_mem >= self.max_memory:
                 right = mid - gap
             else:
...
@@ -153,14 +134,13 @@ class SelectChunk(object):
     def _get_compute_node_num(self, start, end):
         count = 0
         for i in self.trace_indice.node_list[start:end + 1]:
             if not is_non_compute_node(i):
                 count += 1
         return count

-    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
-                                        mem_peak):
+    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
         # remove illegal regions
         illegal_regions = []
         for i in possible_chunk_regions:
...
@@ -173,37 +153,31 @@ class SelectChunk(object):
         if len(possible_chunk_regions) == 0:
             return None

         # get max possible chunk region
         max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
                                      max([i["region"][1] for i in possible_chunk_regions]))

         # get mem for chunk region
-        regions_dict = []
+        regions_dict_list = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list,
-                                                                       cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list,
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]:max_chunk_region[1] + 1]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
-            regions_dict.append(
-                {
-                    "chunk_info": region,
-                    "chunk_max_mem": cur_chunk_region_max_peak,
-                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
-                    "reorder_chunk_info": cur_region,
-                    "reorder_node_list": cur_node_list,
-                }
-            )
+            regions_dict_list.append({
+                "chunk_info": region,
+                "chunk_max_mem": cur_chunk_region_max_peak,
+                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                "reorder_chunk_info": cur_region,
+                "reorder_node_list": cur_node_list,
+            })

         # select the min mem
-        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
+        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
         best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
-        best_region = regions_dict[best_region_idx]["chunk_info"]
+        best_region = regions_dict_list[best_region_idx]["chunk_info"]
         if best_region is not None:
             best_region["chunk_size"] = 1
         return best_region
...
@@ -216,9 +190,7 @@ class SelectChunk(object):
             return False
         for i in chunk_infos:
             region = i["region"]
-            if not (
-                (chunk_region_start > region[1] and chunk_region_end > region[1])
-                or (chunk_region_start < region[0] and chunk_region_end < region[0])
-            ):
+            if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
+                    (chunk_region_start < region[0] and chunk_region_end < region[0])):
                 return False
         return True
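_chunk_size_binary_search above narrows the chunk size between chunk_size // 2 and chunk_size, re-estimating peak memory at each midpoint and stepping the bounds by a gap. A self-contained sketch of that search shape, with a generic fits_in_memory callback standing in for the memory estimator (the callback and the default gap are assumptions for illustration, not the class's API):

from typing import Callable


def chunk_size_binary_search(left: int, right: int,
                             fits_in_memory: Callable[[int], bool],
                             gap: int = 1) -> int:
    """Largest chunk size in [left, right] whose estimated peak memory fits.

    Follows the loop shape of SelectChunk._chunk_size_binary_search:
    mid = int((left + right) / 2 + 0.5), then move right down (or left up) by gap.
    """
    best = left
    while left <= right:
        mid = int((left + right) / 2 + 0.5)
        if fits_in_memory(mid):
            best = mid
            left = mid + gap
        else:
            right = mid - gap
    return best

# usage: largest chunk size under a 24 MB budget (estimate_peak_mem is hypothetical)
# size = chunk_size_binary_search(1, 64, lambda s: estimate_peak_mem(s) < 24)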
colossalai/autochunk/trace_flow.py

...
@@ -8,9 +8,9 @@ from .utils import (
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
     flat_list,
+    get_node_name,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
 )
...
@@ -79,43 +79,6 @@ class TraceFlow(object):
                 return node_dim
         return None

-    def check_index_duplicate(self, chunk_infos, return_dim=False):
-        input_dim_after_node = {}
-        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
-            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
-                inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
-                if inherit_dim:
-                    input_dim_after_node[k] = inherit_dim
-
-        for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
-            if is_non_compute_node_except_placeholder(node):
-                continue
-            count = 0
-            duplicate_dims = []
-            node_trace_source = self.trace_indice._find_source_trace_from_node(node)
-            for node_dim in range(len(get_node_shape(node))):
-                duplicate_dim = []
-                duplicate_flag = False
-                dim_source = node_trace_source[node_dim]
-                for k, v in dim_source.items():
-                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
-                        if k in input_dim_after_node and input_dim_after_node[k] in v:
-                            duplicate_flag = True
-                            duplicate_dim.append((k, v))
-                duplicate_dims.append(duplicate_dim)
-                if duplicate_flag:
-                    count += 1
-
-            if count > 1:
-                if return_dim:
-                    return False, duplicate_dims
-                else:
-                    return False
-
-        if return_dim:
-            return True, None
-        else:
-            return True
-
     def _assgin_single_node_flow(
         self,
         arg_node: Node,
...
@@ -225,9 +188,12 @@ class TraceFlow(object):
             if flow_flag == False:
                 return None

-        if len(arg_list) == 2:
-            if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
+        if len(arg_list) >= 2:
+            # need to mark fix dim
+            if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
                 for arg in arg_list:
                     if get_node_shape(arg) is None:
                         continue
+                    if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
+                        continue
                     arg_chunk_dim = all_node_info[arg]["chunk_dim"]
...
@@ -240,9 +206,8 @@ class TraceFlow(object):
                         return None
                     if i not in arg_fix_dim:
                         arg_fix_dim.append(i)
-        elif "einsum" in cur_node.name:
-            pass
-        elif "matmul" in cur_node.name:
+        elif any(i == get_node_name(cur_node) for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
            pass
         else:
             raise NotImplementedError()
...
@@ -426,7 +391,7 @@ class TraceFlow(object):
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
-            if any(i in node.name for i in ["reshape", "view"]):
+            if any(i == get_node_name(node) for i in ["reshape", "view"]):
                 reshape_args = flat_list(node.args[1:])
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                 new_shape = ""
...
@@ -443,3 +408,62 @@ class TraceFlow(object):
                 reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
+
+    def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        We are given the region start and region end, and need to find out all chunk info for it.
+        We first loop every dim of the start node and end node, to see if we can find a dim pair
+        which is linked in a flow and not computed.
+        If found, we then search the flow in the whole region to find out all chunk infos.
+
+        Args:
+            input_trace (List): node's input trace in region
+            output_trace (List): node's output trace in region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
+        start_traces = input_trace[start_idx]
+        if len(start_traces) > 1:    # TODO: need to be removed
+            return []
+        end_trace = output_trace[end_idx]
+        end_node = self.trace_indice.node_list[end_idx]
+
+        chunk_infos = []
+        for end_dim, _ in enumerate(end_trace["indice"]):
+            for start_node, start_trace in start_traces.items():
+                for start_dim, _ in enumerate(start_trace["indice"]):
+                    if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
+                        continue
+                    # flow search
+                    chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
+                    if chunk_info is None:
+                        continue
+                    chunk_infos.append(chunk_info)
+        return chunk_infos
+
+    def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int,
+                                end_node: Node, end_dim: int, end_idx: int) -> bool:
+        """
+        check if region start and end is legal
+        """
+        # dim cannot be None
+        if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+            return False
+        # dim size cannot be 1
+        if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+            return False
+        # must have users
+        if len(end_node.users) == 0:
+            return False
+        # check index source align
+        if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
+            return False
+        # check index compute
+        if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
+            return False
+        return True
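The new find_chunk_info above enumerates every (start_dim, end_dim) pair, filters pairs with _check_region_start_end, and keeps whatever flow_search returns. Stripped of the tracer details, the search skeleton looks like this (a sketch; check and search are stand-ins for the two methods, not real names):

def enumerate_chunk_candidates(start_dims, end_dims, check, search):
    """Sketch of TraceFlow.find_chunk_info's nested dim-pair enumeration.

    check(s, e) plays the role of _check_region_start_end and
    search(s, e) the role of flow_search; both are hypothetical callbacks.
    """
    found = []
    for e in end_dims:
        for s in start_dims:
            if not check(s, e):
                continue
            info = search(s, e)
            if info is not None:
                found.append(info)
    return found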
colossalai/autochunk/trace_indice.py

(This diff is collapsed in the web view; +241 −100 per the summary above.)
colossalai/autochunk/utils.py

-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

 from torch.fx.node import Node

 from colossalai.logging import get_dist_logger

+NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
+NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
 logger = get_dist_logger()


-def get_logger():
+def get_logger() -> Any:
     return logger
...
@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:
 def is_non_compute_node(node: Node) -> bool:
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
+    if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
         return True
     if "getitem" in node.name:
         node_args = flat_list(node.args[1:])
...
@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
     return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder(node):
+def is_non_compute_node_except_placeholder(node: Node) -> bool:
     if "placeholder" in node.op:
         return False
     return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder_output(node):
+def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
     if "output" in node.op:
         return False
     return is_non_compute_node_except_placeholder(node)


-def find_idx_by_name(name, nodes_list):
+def find_idx_by_name(name: str, nodes_list: List) -> int:
     for idx, node in enumerate(nodes_list):
         if node.name == name:
             return idx
     raise RuntimeError("name %s not found in node list" % name)


-def delete_free_var_from_last_use(user_to_last_uses):
+def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
     for key, value in user_to_last_uses.items():
         for n in value:
             if n.op == "placeholder":
                 user_to_last_uses[key].remove(n)


-def find_chunk_all_input_nodes(nodes: List[Node]):
+def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
     """
     Find non-compute input and output node names.
     input nodes are nodes used in the list
...
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
     return input_nodes


-def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
+def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]:
     """
     Find non-compute input and output node names.
     input nodes are nodes used in the list
...
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
             output_nodes.append(node)

     return input_nodes, output_nodes
+
+
+def get_module_node_name(node: Node) -> str:
+    """
+    get module class name
+    """
+    node_targets = node.target.split(".")
+    module = node.graph.owning_module
+    for i in node_targets:
+        module = getattr(module, i)
+    module_name = str(module.__class__).split(".")[-1][:-2]
+    module_name = module_name.lower()
+    return module_name
+
+
+def get_node_name(node: Node) -> str:
+    """
+    get node name
+    """
+    node_name = node.name
+    if "_" in node_name:
+        for i in range(len(node_name) - 1, -1, -1):
+            if node_name[i] == "_":
+                node_name = node_name[:i]
+                break
+            elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
+                continue
+            else:
+                break
+    return node_name
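The new get_node_name helper strips the numeric de-duplication suffix torch.fx appends to repeated ops, which is what lets the call sites above switch from substring tests to exact matches. A quick check of the suffix-stripping behaviour (the examples and the _FakeNode stand-in are assumptions for illustration; the function only reads node.name):

from colossalai.autochunk.utils import get_node_name


class _FakeNode:
    # minimal stand-in with just the .name attribute that get_node_name reads
    def __init__(self, name):
        self.name = name


# scan runs from the right: digits are skipped, a "_" truncates,
# any other character stops the scan with the name unchanged
assert get_node_name(_FakeNode("add_1")) == "add"
assert get_node_name(_FakeNode("matmul_12")) == "matmul"
assert get_node_name(_FakeNode("layer_norm")) == "layer_norm"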
tests/test_autochunk/benchmark_simple_evoformer.py (deleted, 100644 → 0)

import time

import torch
import torch.fx
from simple_evoformer import base_evoformer, openfold_evoformer

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor


def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024**2

    loop = 3
    with torch.no_grad():
        for _ in range(loop // 2 + 1):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        torch.cuda.synchronize()
        time1 = time.time()
        for _ in range(loop):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        torch.cuda.synchronize()
        time2 = time.time()

    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))


def _build_autochunk(model, max_memory, node, pair):
    # trace the module and replace codegen
    graph = ColoTracer().trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )

    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # now run it twice to get meta info in graph module, not necessary
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # set code_gen
    codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
    graph.set_codegen(codegen)

    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # print
    # code = graph.python_code("self").src
    # print(code)
    return gm


def benchmark_evoformer():
    # init data and model
    msa_len = 128
    pair_len = 256
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    model = base_evoformer().cuda()

    # build autochunk model
    # max_memory = 1000  # MB, fit memory mode
    max_memory = None    # min memory mode
    autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)

    # build openfold
    chunk_size = 64
    openfold = openfold_evoformer().cuda()

    # benchmark
    _benchmark_evoformer(model, node, pair, "base")
    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
    _benchmark_evoformer(autochunk, node, pair, "autochunk")


if __name__ == "__main__":
    benchmark_evoformer()
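The deleted benchmark followed the usual CUDA timing recipe: warm up, synchronize, time a fixed number of iterations, and report peak memory relative to a baseline. A condensed sketch of that recipe over a generic callable (the function name is illustrative; the loop counts mirror the deleted file):

import time

import torch


def benchmark_cuda(fn, loop: int = 3) -> tuple:
    """Return (seconds per call, peak extra MB), as the deleted script measured."""
    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024**2
    with torch.no_grad():
        for _ in range(loop // 2 + 1):    # warm-up iterations
            fn()
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(loop):
            fn()
        torch.cuda.synchronize()
        t1 = time.time()
    peak = torch.cuda.max_memory_allocated() / 1024**2
    return (t1 - t0) / loop, peak - now_mem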
tests/test_autochunk/test_alphafold/test_alphafold_utils.py (new file, 0 → 100644)

from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    meta_args: List,
    concrete_args: List = None,
    max_memory: int = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    if concrete_args is None:
        concrete_args = []

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    model.cuda()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
    for out_gm_i, out_model_i in zip(out_gm, out_model):
        assert torch.allclose(out_gm_i, out_model_i,
                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                                  torch.abs(out_gm_i - out_model_i))

    return chunks


def run_test(
    rank: int,
    data_args: tuple,
    max_memory: int,
    get_model: Any,
    get_data: Any,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = get_model()
    meta_args, concrete_args = get_data(*data_args)

    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        concrete_args=concrete_args,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )
tests/test_autochunk/test_alphafold/test_evoformer_block.py (new file, 0 → 100644)

from functools import partial
from typing import Dict, List, Tuple

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except:
    HAS_REPO = False

from test_alphafold_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_model():
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model


def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_mask_trans", True)]
    return meta_args, concrete_args


def get_chunk_target() -> Dict:
    return {
        None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242),
               (21, 46)],
        20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
        24: [(118, 123)],
    }


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_block(data_args, max_memory):
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_alphafold/test_evoformer_stack.py (new file, 0 → 100644)

from functools import partial
from typing import List, Tuple

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerStack
    HAS_REPO = True
except:
    HAS_REPO = False

from test_alphafold_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_model():
    model = EvoformerStack(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        c_s=384,
        no_heads_msa=8,
        no_heads_pair=4,
        no_blocks=2,    # 48
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=None,
        inf=1000000000.0,
        eps=1e-08,
        clear_cache_between_blocks=False,
        is_multimer=False,
    ).eval().cuda()
    return model


def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_mask_trans", True)]
    return meta_args, concrete_args


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_alphafold/test_extramsa_block.py (new file, 0 → 100644)

from functools import partial
from typing import Dict, List, Tuple

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import ExtraMSABlock
    HAS_REPO = True
except:
    HAS_REPO = False

from test_alphafold_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_model():
    model = ExtraMSABlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    ).eval().cuda()
    return model


def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
    return meta_args, concrete_args


def get_chunk_target() -> Dict:
    return {
        None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
               (33, 46)],
        20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
        24: [(126, 131)],
    }


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_simple_evoformer_codegen.py → tests/test_autochunk/test_diffuser/test_diffuser_utils.py

-from functools import partial
+from typing import Any, Dict, List

-import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp
-
-try:
-    from simple_evoformer import base_evoformer
-    HAS_REPO = True
-except:
-    HAS_REPO = False

 import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer, symbolic_trace
-from colossalai.fx._compatibility import is_compatible_with_meta
-from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port

-if CODEGEN_AVAILABLE and is_compatible_with_meta():
+if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
     from colossalai.fx.profiler import MetaTensor
     from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


+def assert_codegen_run(
+    model: Any,
+    meta_args: List,
+    concrete_args: List = None,
+    max_memory: int = None,
+    print_mem: bool = False,
+    print_progress: bool = False,
+    print_code: bool = False,
+) -> List[Dict]:
+    if concrete_args is None:
+        concrete_args = []
+    model = model()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    interp = MetaInfoProp(meta_graph)
+    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
+    interp.propagate(*meta_tensors)
+    codegen = AutoChunkCodeGen(
+        meta_graph,
+        max_memory=max_memory,
+        print_mem=print_mem,
+        print_progress=print_progress,
+    )
+    chunks = codegen.chunk_infos
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model.cuda(),
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()

-def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    with torch.no_grad():
-        non_fx_out = model(node, pair)
-        fx_out = gm(node, pair)
+    # assert chunk in code
+    code = graph.python_code("self").src
+    if print_code:
+        print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

-    assert torch.allclose(non_fx_out[0], fx_out[0],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-                              torch.abs(non_fx_out[0] - fx_out[0]))
-    assert torch.allclose(non_fx_out[1], fx_out[1],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-                              torch.abs(non_fx_out[1] - fx_out[1]))
+    # assert result
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    model.cuda().eval()
+    gm.eval()
+    with torch.no_grad():
+        out_gm = gm(*inputs)
+        out_model = model(*inputs)
+    assert torch.allclose(out_gm["sample"], out_model["sample"],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(out_gm["sample"] - out_model["sample"]))
+    return chunks


-def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
+def run_test(
+    rank: int,
+    model: Any,
+    data: tuple,
+    max_memory: int,
+    print_code: bool,
+    print_mem: bool,
+    print_progress: bool,
+    get_chunk_target: Any = None,
+) -> None:
     # launch colossalai
     colossalai.launch(
         config={},
...
@@ -50,55 +98,23 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     )

     # build model and input
-    model = base_evoformer().cuda()
-    node = torch.randn(1, msa_len, pair_len, 256).cuda()
-    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
-
-    # meta info prop
-    meta_graph = symbolic_trace(model,
-                                meta_args={
-                                    "node": node.to(torch.device("meta")),
-                                    "pair": pair.to(torch.device("meta")),
-                                })    # must use symbolic_trace
-    interp = MetaInfoProp(meta_graph)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
-
-    # trace the module and replace codegen
-    graph = ColoTracer().trace(
-        model,
-        meta_args={
-            "node": node.to(torch.device("meta")),
-            "pair": pair.to(torch.device("meta")),
-        },
-    )
-    graph.set_codegen(codegen)
-    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
-    gm.recompile()
-
-    # assert we have inserted chunk
-    code = graph.python_code("self").src
-    # print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
-
-    _test_fwd(model, gm, node, pair)
-    gpc.destroy()
-
-
-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
-                    reason='torch version is lower than 1.12.0')
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
-@pytest.mark.parametrize("msa_len", [32])
-@pytest.mark.parametrize("pair_len", [64])
-def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
-    run_func = partial(_test_simple_evoformer_codegen, msa_len=msa_len, pair_len=pair_len,
-                       max_memory=max_memory)
-    mp.spawn(run_func, nprocs=1)
+    meta_args, concrete_args = data
+    chunks = assert_codegen_run(
+        model,
+        meta_args=meta_args,
+        concrete_args=concrete_args,
+        max_memory=max_memory,
+        print_code=print_code,
+        print_mem=print_mem,
+        print_progress=print_progress,
+    )
+
+    if get_chunk_target is not None:
+        chunk_found = [i["region"] for i in chunks]
+        chunk_target = get_chunk_target()[max_memory]
+        assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+            str(chunk_found),
+            str(chunk_target),
+        )

-if __name__ == "__main__":
-    _test_simple_evoformer_codegen(0, 32, 64, 25)
+    gpc.destroy()
tests/test_autochunk/test_diffuser/test_unet.py (new file, 0 → 100644)

from functools import partial
from typing import List, Tuple

import pytest
import torch
import torch.multiprocessing as mp

try:
    from diffusers import UNet2DModel
    MODELS = [UNet2DModel]
    HAS_REPO = True
except:
    MODELS = []
    HAS_REPO = False

from test_diffuser_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)


def get_data(shape: tuple) -> Tuple[List, List]:
    sample = torch.randn(shape)
    meta_args = [
        ("sample", sample),
    ]
    concrete_args = [("timestep", 50)]
    return meta_args, concrete_args


@pytest.mark.skipif(
    True,
    reason="not implemented",
)
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
def test_evoformer_block(model, shape, max_memory):
    run_func = partial(
        run_test,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data=get_data(LATENTS_SHAPE),
        max_memory=64,
        model=UNet2DModel,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
tests/test_autochunk/test_evoformer_codegen.py (deleted, 100644 → 0)

from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))


def _build_openfold():
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model


def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()


@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_evoformer_codegen(0, 32, 64, 24)
tests/test_autochunk/test_evoformer_stack_codegen.py (deleted, 100644 → 0)

from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerStack
    HAS_REPO = True
except:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1, None)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask, None)
        fx_out = gm(node, pair, node_mask, pair_mask, None)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))


def _build_openfold():
    model = EvoformerStack(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        c_s=384,
        no_heads_msa=8,
        no_heads_pair=4,
        no_blocks=2,    # 48
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=None,
        inf=1000000000.0,
        eps=1e-08,
        clear_cache_between_blocks=False,
        is_multimer=False,
    ).eval().cuda()
    return model


def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
        None,
    )
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()


@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_evoformer_stack_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_evoformer_stack_codegen(0, 32, 64, None)
tests/test_autochunk/test_extramsa_codegen.py
deleted
100644 → 0
View file @
6e0faa70
from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import ExtraMSABlock
    HAS_REPO = True
except:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))


def _build_openfold():
    model = ExtraMSABlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    ).eval().cuda()
    return model


def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_chunk_logits": 1024,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_chunk_logits": 1024,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()


@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_extramsa_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_extramsa_codegen(0, 32, 64, None)
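Both traces in this test pin `chunk_size` and `_chunk_logits` through `concrete_args`. As in stock torch.fx, a concrete argument is baked into the graph as a constant, so branches that depend on it are specialized away rather than traced symbolically. A standalone sketch of that behavior with plain torch.fx (the `Toy` module is illustrative only):

import torch
import torch.fx


class Toy(torch.nn.Module):

    def forward(self, x, use_double: bool = True):
        # symbolic tracing alone cannot resolve this data-dependent branch
        if use_double:
            return x * 2
        return x


# pinning `use_double` lets the tracer follow exactly one branch
gm = torch.fx.symbolic_trace(Toy(), concrete_args={"use_double": True})
print(gm.code)    # only the `x * 2` path appears in the traced forward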
tests/test_autochunk/test_simple_evoformer_search.py
deleted
100644 → 0
View file @
6e0faa70
from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from simple_evoformer import base_evoformer
    HAS_REPO = True
except:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor


def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
    found_regions = [i["region"] for i in chunk_infos]

    if msa_len == 32 and pair_len == 64:
        if max_memory is None:
            target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
                              (161, 166), (198, 203), (7, 57)]
        elif max_memory == 20:
            target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]
        elif max_memory == 25:
            target_regions = [(144, 154), (369, 370)]
        elif max_memory == 30:
            target_regions = [(144, 154)]
        else:
            raise NotImplementedError()
    else:
        raise NotImplementedError()

    assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
        str(found_regions),
        str(target_regions),
    )


def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = base_evoformer().cuda()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()

    meta_graph = symbolic_trace(model,
                                meta_args={
                                    "node": node.to(torch.device("meta")),
                                    "pair": pair.to(torch.device("meta")),
                                })    # must use symbolic_trace
    interp = MetaInfoProp(meta_graph)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    chunk_infos = codegen.chunk_infos
    assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)

    gpc.destroy()


@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
                    reason="torch version is lower than 1.12.0")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_simple_evoformer_search,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_simple_evoformer_search(0, 32, 64, 20)
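This search test never generates code; it only checks `codegen.chunk_infos`, where each entry is a dict whose `"region"` value is the `(start, end)` pair of graph-node indices the search picked for chunking. A small sketch of inspecting a result the same way (assuming `codegen` was built as in the test above):

# assumes `codegen` is an AutoChunkCodeGen instance built as in the test above
for info in codegen.chunk_infos:
    start, end = info["region"]    # node indices delimiting one chunked span
    print(f"chunk region spans graph nodes {start}..{end}")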
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
0 → 100644
View file @
63199c66
from functools import partial
from typing import Dict, List, Tuple

import pytest
import torch
import torch.multiprocessing as mp

try:
    from transformers import GPT2Config, GPT2Model
    MODELS = [GPT2Model]
    HAS_REPO = True
except:
    MODELS = []
    HAS_REPO = False

from test_transformer_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

BATCH_SIZE = 2
SEQ_LENGTH = 256


def get_data(shape: tuple) -> Tuple[Dict, Dict, List]:
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    attention_mask = torch.ones(shape, dtype=torch.int64)
    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    concrete_args = {"past_key_values": None}
    sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
    return meta_args, concrete_args, sequence


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
    run_func = partial(
        run_test,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data=get_data((BATCH_SIZE, SEQ_LENGTH)),
        max_memory=None,
        model=GPT2Model,
        config=GPT2Config(n_embd=96, n_positions=SEQ_LENGTH, n_layer=2, n_head=4),
        print_code=True,
        print_mem=True,
        print_progress=False,
    )
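Since `run_test` only consumes the `(meta_args, concrete_args, sequence)` triple returned by `get_data`, covering another HuggingFace model is mostly a matter of writing an analogous data helper; `sequence` must list the traced forward's parameters in positional order. A hedged sketch for BERT (`BertConfig`/`BertModel` are standard transformers classes, but the argument order and the helper itself are assumptions, not part of this test suite):

import torch
from transformers import BertConfig, BertModel


def get_bert_data(shape: tuple):
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    attention_mask = torch.ones(shape, dtype=torch.int64)
    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    # BertModel.forward takes attention_mask before token_type_ids positionally
    sequence = ["input_ids", "attention_mask", "token_type_ids"]
    return meta_args, None, sequence    # no concrete args needed here (assumption)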
tests/test_autochunk/test_transformer/test_transformer_utils.py
0 → 100644
View file @
63199c66
from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    data: tuple,
    max_memory: int = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    meta_args, concrete_args, sequence = data
    if concrete_args is None:
        concrete_args = {}

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    # assert result
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    for k in out_model.keys():
        if torch.is_tensor(out_gm[k]):
            assert torch.equal(
                out_model[k], out_gm[k]
            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
    return chunks


def run_test(
    rank: int,
    model: Any,
    config: Any,
    data: tuple,
    max_memory: int,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    model = model(config=config)

    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    chunks = assert_codegen_run(
        model,
        data=data,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )

    gpc.destroy()
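The optional `get_chunk_target` hook turns this smoke test into a regression test: it should return a dict mapping each `max_memory` setting to the exact chunk regions the search is expected to find. A usage sketch (the region tuples below are placeholders, not real search output):

def get_gpt_chunk_target():
    # placeholder regions for illustration; real values come from a known-good run
    return {
        None: [(100, 142), (148, 190)],
        4.5: [(100, 142)],
    }

# run_test(rank=0, model=GPT2Model, config=GPT2Config(n_embd=96, n_positions=256, n_layer=2, n_head=4),
#          data=get_data((2, 256)), max_memory=4.5, print_code=False, print_mem=False,
#          print_progress=False, get_chunk_target=get_gpt_chunk_target)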