Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
72341e65
Unverified
Commit
72341e65
authored
Jan 20, 2023
by
oahzxl
Committed by
GitHub
Jan 20, 2023
Browse files
[auto-chunk] support extramsa (#3) (#2504)
parent
0f02b8c6
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
283 additions
and
54 deletions
+283
-54
colossalai/autochunk/estimate_memory.py
colossalai/autochunk/estimate_memory.py
+2
-7
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+32
-11
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+52
-4
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+16
-4
tests/test_autochunk/test_evoformer_codegen.py
tests/test_autochunk/test_evoformer_codegen.py
+1
-1
tests/test_autochunk/test_extramsa_codegen.py
tests/test_autochunk/test_extramsa_codegen.py
+164
-0
tests/test_autochunk/test_simple_evoformer_codegen.py
tests/test_autochunk/test_simple_evoformer_codegen.py
+1
-1
tests/test_autochunk/test_simple_evoformer_search.py
tests/test_autochunk/test_simple_evoformer_search.py
+15
-26
No files found.
colossalai/autochunk/estimate_memory.py
View file @
72341e65
...
@@ -6,12 +6,7 @@ from torch.fx.node import Node, map_arg
...
@@ -6,12 +6,7 @@ from torch.fx.node import Node, map_arg
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
from
.utils
import
(
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
,
is_non_memory_node
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
,
is_non_compute_node_except_placeholder
,
)
class
EstimateMemory
(
object
):
class
EstimateMemory
(
object
):
...
@@ -240,7 +235,7 @@ class EstimateMemory(object):
...
@@ -240,7 +235,7 @@ class EstimateMemory(object):
elif
node
.
op
==
"output"
:
elif
node
.
op
==
"output"
:
continue
continue
# no change for non compute node
# no change for non compute node
elif
is_non_
compute_node_except_placehol
de
r
(
node
):
elif
is_non_
memory_no
de
(
node
):
act_memory_peak_log
.
append
(
act_memory
)
act_memory_peak_log
.
append
(
act_memory
)
# node is a compute op
# node is a compute op
# calculate tmp, output node and delete node memory
# calculate tmp, output node and delete node memory
...
...
colossalai/autochunk/trace_flow.py
View file @
72341e65
...
@@ -118,16 +118,34 @@ class TraceFlow(object):
...
@@ -118,16 +118,34 @@ class TraceFlow(object):
def
_assgin_single_node_flow
(
def
_assgin_single_node_flow
(
self
,
self
,
arg_node
,
arg_node
:
Node
,
start_idx
,
start_idx
:
int
,
end_idx
,
end_idx
:
int
,
cur_node_dim
,
cur_node_dim
:
int
,
cur_node_compute
,
cur_node_compute
:
Dict
,
cur_node_source
,
cur_node_source
:
Dict
,
cur_node_fix_dim
,
cur_node_fix_dim
:
List
,
all_node_info
,
all_node_info
:
Dict
,
next_node_list
,
next_node_list
:
List
,
):
)
->
bool
:
"""
Given the current node and one of its arg node,
this function finds out arg node's chunk dim and fix dim
Args:
arg_node (Node): input node
start_idx (int): chunk region start
end_idx (int): chunk region end
cur_node_dim (int): current node chunk dim
cur_node_compute (Dict): current node compute dict
cur_node_source (Dict): current node source dict
cur_node_fix_dim (List): current node fix dim
all_node_info (Dict): all node chunk info in the chunk region
next_node_list (List)
Returns:
bool: True if this node can be added to the flow, vice versa.
"""
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
trace_indice
.
node_list
)
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
trace_indice
.
node_list
)
# arg in chunk range or be inputs
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
...
@@ -142,6 +160,9 @@ class TraceFlow(object):
...
@@ -142,6 +160,9 @@ class TraceFlow(object):
arg_dim
=
None
arg_dim
=
None
else
:
else
:
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
# chunk dim should be None if shape size is 1
if
get_node_shape
(
arg_node
)[
arg_dim
]
==
1
:
arg_dim
=
None
else
:
else
:
arg_dim
=
None
arg_dim
=
None
...
@@ -184,7 +205,7 @@ class TraceFlow(object):
...
@@ -184,7 +205,7 @@ class TraceFlow(object):
# get all valid args
# get all valid args
arg_list
=
[]
arg_list
=
[]
for
arg
in
cur_node
.
a
rg
s
:
for
arg
in
cur_node
.
a
ll_input_node
s
:
if
type
(
arg
)
!=
type
(
cur_node
):
if
type
(
arg
)
!=
type
(
cur_node
):
continue
continue
if
is_non_compute_node
(
arg
):
if
is_non_compute_node
(
arg
):
...
...
colossalai/autochunk/trace_indice.py
View file @
72341e65
...
@@ -432,6 +432,38 @@ class TraceIndice(object):
...
@@ -432,6 +432,38 @@ class TraceIndice(object):
"""
"""
self
.
_assign_all_indice
(
node
,
node_idx
)
self
.
_assign_all_indice
(
node
,
node_idx
)
def _assign_cat_indice(self, node: Node, node_idx: int):
    """
    Assign indice for cat op.

    Steps, in order:
        1. flatten ``node.args[0]`` (the inputs being concatenated) with ``flat_list``
        2. call ``_assign_indice_as_input`` with the first flattened input
        3. call ``_mark_computation_from_node`` for every remaining input
        4. ``_del_dim`` then ``_add_dim`` at ``node.kwargs["dim"]``

    The cat dim is looked up as ``node.kwargs["dim"]``, so ``dim`` has to be
    passed to cat as a keyword argument.

    Args:
        node (node)
        node_idx (int)
    """
    cat_inputs = flat_list(node.args[0])
    first_input, other_inputs = cat_inputs[0], cat_inputs[1:]

    self._assign_indice_as_input(node, node_idx, input_node=first_input)
    for other in other_inputs:
        self._mark_computation_from_node(other, node)

    dim = node.kwargs["dim"]
    # remove the dim and add it straight back; presumably this gives the
    # concatenated dim a fresh indice -- confirm against _del_dim/_add_dim
    self._del_dim(node_idx, dim)
    self._add_dim(node_idx, dim)
def _assign_sum_indice(self, node: Node, node_idx: int):
    """
    Assign indice for sum op.

    Steps, in order:
        1. flatten ``node.args[0]`` with ``flat_list``
        2. ``_add_dim`` at position 0
        3. call ``_assign_indice_as_input`` with the first flattened input
        4. call ``_mark_computation_from_node`` for every remaining input
        5. ``_del_dim`` at ``node.kwargs["dim"]``

    The summed dim is looked up as ``node.kwargs["dim"]``, so ``dim`` has to
    be passed to sum as a keyword argument.

    Args:
        node (node)
        node_idx (int)
    """
    sum_inputs = flat_list(node.args[0])

    # NOTE(review): a dim is added at 0 before the input indice is assigned;
    # the call order is kept exactly as-is because the helpers may depend on it
    self._add_dim(node_idx, 0)

    first_input, other_inputs = sum_inputs[0], sum_inputs[1:]
    self._assign_indice_as_input(node, node_idx, input_node=first_input)
    for other in other_inputs:
        self._mark_computation_from_node(other, node)

    reduced_dim = node.kwargs["dim"]
    self._del_dim(node_idx, reduced_dim)
def
_assign_getitem_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
def
_assign_getitem_indice
(
self
,
node
:
Node
,
node_idx
:
int
):
"""
"""
Assign indice for getitem.
Assign indice for getitem.
...
@@ -442,7 +474,16 @@ class TraceIndice(object):
...
@@ -442,7 +474,16 @@ class TraceIndice(object):
node_idx (int)
node_idx (int)
"""
"""
node_args
=
flat_list
(
node
.
args
[
1
:])
node_args
=
flat_list
(
node
.
args
[
1
:])
if
not
any
(
i
==
str
(
node_arg
)
for
i
in
[
"None"
,
"Ellipsis"
]
for
node_arg
in
node_args
):
flag
=
False
for
node_arg
in
node_args
:
node_arg_str
=
str
(
node_arg
)
if
any
(
i
==
node_arg_str
for
i
in
[
"None"
,
"Ellipsis"
]):
flag
=
True
break
if
"slice"
in
node_arg_str
:
flag
=
True
break
if
flag
==
False
:
return
return
# node args should be like [Ellipsis, slice(start, step, end), None]
# node args should be like [Ellipsis, slice(start, step, end), None]
...
@@ -461,8 +502,11 @@ class TraceIndice(object):
...
@@ -461,8 +502,11 @@ class TraceIndice(object):
shape_gap
=
len
(
node_shape
)
-
len
(
node_args
)
+
1
shape_gap
=
len
(
node_shape
)
-
len
(
node_args
)
+
1
origin_idx_count
+=
shape_gap
origin_idx_count
+=
shape_gap
new_idx_count
+=
shape_gap
new_idx_count
+=
shape_gap
# slice(None, None, None) means all indexes, doesn't support other slice
# slice(None, None, None) means all indexes
elif
"slice(None, None, None)"
==
node_arg_str
:
elif
"slice"
in
node_arg_str
:
if
"slice(None, None, None)"
!=
node_arg_str
:
self
.
_del_dim
(
node_idx
,
new_idx_count
)
self
.
_add_dim
(
node_idx
,
new_idx_count
)
origin_idx_count
+=
1
origin_idx_count
+=
1
new_idx_count
+=
1
new_idx_count
+=
1
# None means a new dim
# None means a new dim
...
@@ -565,7 +609,7 @@ class TraceIndice(object):
...
@@ -565,7 +609,7 @@ class TraceIndice(object):
self
.
_assign_view_reshape_indice
(
node
,
idx
)
self
.
_assign_view_reshape_indice
(
node
,
idx
)
elif
"unsqueeze"
in
node
.
name
:
elif
"unsqueeze"
in
node
.
name
:
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
,
"clone"
]):
self
.
_assgin_no_change_indice
(
node
,
idx
)
self
.
_assgin_no_change_indice
(
node
,
idx
)
elif
"new_ones"
in
node
.
name
:
elif
"new_ones"
in
node
.
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
self
.
_assign_ones_like_indice
(
node
,
idx
)
...
@@ -574,6 +618,8 @@ class TraceIndice(object):
...
@@ -574,6 +618,8 @@ class TraceIndice(object):
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
if
"linear"
in
node
.
name
:
if
"linear"
in
node
.
name
:
self
.
_assign_linear_indice
(
node
,
idx
)
self
.
_assign_linear_indice
(
node
,
idx
)
elif
"cat"
in
node
.
name
:
self
.
_assign_cat_indice
(
node
,
idx
)
elif
"matmul"
in
node
.
name
:
elif
"matmul"
in
node
.
name
:
self
.
_assign_matmul_indice
(
node
,
idx
)
self
.
_assign_matmul_indice
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
elif
"softmax"
in
node
.
name
:
...
@@ -586,6 +632,8 @@ class TraceIndice(object):
...
@@ -586,6 +632,8 @@ class TraceIndice(object):
self
.
_assign_dropout_indice
(
node
,
idx
)
self
.
_assign_dropout_indice
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_indice
(
node
,
idx
)
self
.
_assign_einsum_indice
(
node
,
idx
)
elif
"sum"
in
node
.
name
:
self
.
_assign_sum_indice
(
node
,
idx
)
elif
"layer_norm"
in
node
.
name
:
elif
"layer_norm"
in
node
.
name
:
self
.
_assign_layernorm_indice
(
node
,
idx
)
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
"getitem"
in
node
.
name
:
elif
"getitem"
in
node
.
name
:
...
...
colossalai/autochunk/utils.py
View file @
72341e65
...
@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
...
@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
def
flat_list
(
inputs
)
:
def
flat_list
(
inputs
:
Any
)
->
List
:
"""
"""
flat a list by recursion
flat a list by recursion
"""
"""
if
not
(
isinstance
(
inputs
,
list
)
or
isinstance
(
inputs
,
set
)
or
isinstance
(
inputs
,
tuple
)):
return
[
inputs
]
res
=
[]
res
=
[]
for
i
in
inputs
:
for
i
in
inputs
:
if
isinstance
(
i
,
list
)
or
isinstance
(
i
,
set
)
or
isinstance
(
i
,
tuple
):
if
isinstance
(
i
,
list
)
or
isinstance
(
i
,
set
)
or
isinstance
(
i
,
tuple
):
...
@@ -16,7 +18,7 @@ def flat_list(inputs):
...
@@ -16,7 +18,7 @@ def flat_list(inputs):
return
res
return
res
def
find_first_tensor_arg
(
node
)
:
def
find_first_tensor_arg
(
node
:
Node
)
->
Node
:
"""
"""
Find the first input tensor arg for a node
Find the first input tensor arg for a node
"""
"""
...
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
...
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
raise
RuntimeError
()
raise
RuntimeError
()
def
is_non_compute_node
(
node
)
:
def
is_non_compute_node
(
node
:
Node
)
->
bool
:
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getattr"
]):
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getattr"
]):
return
True
return
True
if
"getitem"
in
node
.
name
:
if
"getitem"
in
node
.
name
:
...
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
...
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
for
node_arg
in
node_args
:
for
node_arg
in
node_args
:
if
any
(
i
==
str
(
node_arg
)
for
i
in
[
"None"
,
"Ellipsis"
]):
if
any
(
i
==
str
(
node_arg
)
for
i
in
[
"None"
,
"Ellipsis"
]):
return
False
return
False
if
"slice"
in
str
(
node_arg
):
return
False
return
True
return
True
return
False
return
False
def
get_node_shape
(
node
)
:
def
get_node_shape
(
node
:
Node
)
->
List
:
if
hasattr
(
node
.
meta
[
"tensor_meta"
],
"shape"
):
if
hasattr
(
node
.
meta
[
"tensor_meta"
],
"shape"
):
return
node
.
meta
[
"tensor_meta"
].
shape
return
node
.
meta
[
"tensor_meta"
].
shape
return
None
return
None
def is_non_memory_node(node: Node) -> bool:
    """
    Tell whether ``node`` is treated as a non memory node.

    Returns True when "getitem" occurs in ``node.name`` or "output" occurs in
    ``node.op``; every other node is decided by ``is_non_compute_node``.
    """
    if "getitem" in node.name or "output" in node.op:
        return True
    return is_non_compute_node(node)
def
is_non_compute_node_except_placeholder
(
node
):
def
is_non_compute_node_except_placeholder
(
node
):
if
"placeholder"
in
node
.
op
:
if
"placeholder"
in
node
.
op
:
return
False
return
False
...
...
tests/test_autochunk/test_evoformer_codegen.py
View file @
72341e65
...
@@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
...
@@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
},
},
)
)
graph
.
set_codegen
(
codegen
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
,
ckpt_codegen
=
False
)
gm
.
recompile
()
gm
.
recompile
()
# assert we have inserted chunk
# assert we have inserted chunk
...
...
tests/test_autochunk/test_extramsa_codegen.py
0 → 100644
View file @
72341e65
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
ExtraMSABlock
HAS_REPO
=
True
except
:
HAS_REPO
=
False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    """
    Run the original model and the recompiled graph module on the same inputs
    and assert that their first two outputs match within ``atol=1e-4``.

    Args:
        model (torch.nn.Module): reference model; moved to cuda before running
        gm (ColoGraphModule): graph module to compare against ``model``
        node, pair, node_mask, pair_mask: forward inputs, passed to both
            ``model`` and ``gm`` in this order
    """
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    inputs = (node, pair, node_mask, pair_mask)
    with torch.no_grad():
        non_fx_out = model(*inputs)
        fx_out = gm(*inputs)

    # outputs 0 and 1 are checked in turn, with the same tolerance and message
    for out_idx in (0, 1):
        expected = non_fx_out[out_idx]
        actual = fx_out[out_idx]
        assert torch.allclose(
            expected, actual, atol=1e-4
        ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(expected - actual))
def _build_openfold():
    """
    Build the ``ExtraMSABlock`` used by this test.

    Returns:
        the block, switched to eval mode and moved to cuda
    """
    block_config = dict(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    )
    return ExtraMSABlock(**block_config).eval().cuda()
def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
    """
    Worker body of the extra-msa autochunk codegen test.

    Launches colossalai for a single process, traces the model twice (once
    with ``symbolic_trace`` for ``MetaInfoProp``, once with ``ColoTracer`` for
    code generation), attaches ``AutoChunkCodeGen``, asserts that chunk code
    was emitted, and compares the recompiled module with the original model.

    Args:
        rank: process rank forwarded to ``colossalai.launch``
        msa_len: size of dim 1 of the msa input and msa mask
        pair_len: size of dim 2 of the msa input and dims 1-2 of the pair input
        max_memory: forwarded to ``AutoChunkCodeGen``
    """
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
    real_inputs = (node, pair, node_mask, pair_mask)

    def make_trace_kwargs():
        # a fresh dict (and fresh meta tensors) for every tracer call
        return {
            "meta_args": {
                "m": node.to(torch.device("meta")),
                "z": pair.to(torch.device("meta")),
                "msa_mask": node_mask.to(torch.device("meta")),
                "pair_mask": pair_mask.to(torch.device("meta")),
            },
            "concrete_args": {
                "chunk_size": None,
                "_chunk_logits": 1024,
            },
        }

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(model, **make_trace_kwargs())
    interp = MetaInfoProp(meta_graph)
    interp.propagate(*[MetaTensor(tensor, fake_device="cuda:0") for tensor in real_inputs])
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(model, **make_trace_kwargs())
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, *real_inputs)
    gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    # the condition also covers a missing fastfold import (HAS_REPO), so the
    # reason has to mention it; otherwise that skip is misreported as a torch
    # version problem
    reason="requires torch >= 1.12.0 and the fastfold package",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
    """
    Pytest entry point: run ``_test_extramsa_codegen`` in one spawned process.

    ``mp.spawn`` supplies the process index as the first positional argument,
    which ``_test_extramsa_codegen`` receives as ``rank``; the remaining
    arguments are bound here with ``partial``.
    """
    run_func = partial(
        _test_extramsa_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # direct run: call the worker in this process as rank 0, without mp.spawn
    _test_extramsa_codegen(rank=0, msa_len=32, pair_len=64, max_memory=None)
tests/test_autochunk/test_simple_evoformer_codegen.py
View file @
72341e65
...
@@ -73,7 +73,7 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
...
@@ -73,7 +73,7 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
},
},
)
)
graph
.
set_codegen
(
codegen
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
,
ckpt_codegen
=
False
)
gm
.
recompile
()
gm
.
recompile
()
# assert we have inserted chunk
# assert we have inserted chunk
...
...
tests/test_autochunk/test_simple_evoformer_search.py
View file @
72341e65
...
@@ -13,6 +13,7 @@ except:
...
@@ -13,6 +13,7 @@ except:
import
colossalai
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx
import
symbolic_trace
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
...
@@ -28,10 +29,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
...
@@ -28,10 +29,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
if
msa_len
==
32
and
pair_len
==
64
:
if
msa_len
==
32
and
pair_len
==
64
:
if
max_memory
is
None
:
if
max_memory
is
None
:
target_regions
=
[(
142
,
154
),
(
366
,
373
),
(
23
3
,
283
),
(
30
1
,
351
),
(
127
,
134
),
(
2
04
,
228
),
(
1
6
7
,
191
),
target_regions
=
[(
142
,
154
),
(
366
,
373
),
(
23
4
,
283
),
(
30
2
,
351
),
(
127
,
134
),
(
2
11
,
228
),
(
17
4
,
191
),
(
161
,
166
),
(
198
,
203
),
(
6
,
69
)]
(
161
,
166
),
(
198
,
203
),
(
7
,
57
)]
elif
max_memory
==
20
:
elif
max_memory
==
20
:
target_regions
=
[(
142
,
154
),
(
369
,
373
),
(
23
3
,
269
),
(
30
1
,
351
)]
target_regions
=
[(
142
,
154
),
(
369
,
373
),
(
23
5
,
269
),
(
30
3
,
351
)
,
(
130
,
131
)
]
elif
max_memory
==
25
:
elif
max_memory
==
25
:
target_regions
=
[(
144
,
154
),
(
369
,
370
)]
target_regions
=
[(
144
,
154
),
(
369
,
370
)]
elif
max_memory
==
30
:
elif
max_memory
==
30
:
...
@@ -41,25 +42,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
...
@@ -41,25 +42,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
assert
len
(
found_regions
)
==
len
(
assert
found_regions
==
target_regions
,
"found regions %s doesn't equal target regions %s"
%
(
target_regions
),
"len of found regions %s doesn't equal len of target regions %s"
%
(
str
(
found_regions
),
str
(
found_regions
),
str
(
target_regions
),
str
(
target_regions
),
)
)
for
region
in
target_regions
:
assert
(
region
in
found_regions
),
"region:%s not in found regions for msa:%d, pair:%d, maxmem:%s"
%
(
str
(
region
),
msa_len
,
pair_len
,
str
(
max_memory
),
)
for
region
in
found_regions
:
assert
(
region
in
target_regions
),
"region:%s should not be found for msa:%d, pair:%d, maxmem:%d"
%
(
str
(
region
),
msa_len
,
pair_len
,
str
(
max_memory
),
)
def
_test_simple_evoformer_search
(
rank
,
msa_len
,
pair_len
,
max_memory
):
def
_test_simple_evoformer_search
(
rank
,
msa_len
,
pair_len
,
max_memory
):
...
@@ -78,11 +64,14 @@ def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
...
@@ -78,11 +64,14 @@ def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
meta_graph
=
symbolic_trace
(
model
,
interp
=
MetaInfoProp
(
gm_prop
)
meta_args
=
{
"node"
:
node
.
to
(
torch
.
device
(
"meta"
)),
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
})
# must use symbolic_trace
interp
=
MetaInfoProp
(
meta_graph
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
codegen
=
AutoChunkCodeGen
(
meta_graph
,
max_memory
=
max_memory
)
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
=
max_memory
)
chunk_infos
=
codegen
.
chunk_infos
chunk_infos
=
codegen
.
chunk_infos
assert_chunk_infos
(
chunk_infos
,
max_memory
,
msa_len
,
pair_len
)
assert_chunk_infos
(
chunk_infos
,
max_memory
,
msa_len
,
pair_len
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment