Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c04f1832
Unverified
Commit
c04f1832
authored
Jan 20, 2023
by
oahzxl
Committed by
GitHub
Jan 20, 2023
Browse files
[autochunk] support parsing blocks (#2506)
parent
35c0c000
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
314 additions
and
22 deletions
+314
-22
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+9
-3
colossalai/autochunk/estimate_memory.py
colossalai/autochunk/estimate_memory.py
+27
-0
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+57
-18
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+4
-1
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+46
-0
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+8
-0
tests/test_autochunk/test_evoformer_stack_codegen.py
tests/test_autochunk/test_evoformer_stack_codegen.py
+163
-0
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
c04f1832
...
...
@@ -22,7 +22,7 @@ if CODEGEN_AVAILABLE:
from
torch.fx.node
import
Argument
,
Node
,
_get_qualified_name
,
_type_repr
,
map_arg
from
.search_chunk
import
SearchChunk
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_logger
,
get_node_shape
def
_gen_chunk_slice_dim
(
chunk_dim
:
int
,
chunk_indice_name
:
str
,
shape
:
List
)
->
str
:
...
...
@@ -276,11 +276,17 @@ if CODEGEN_AVAILABLE:
class
AutoChunkCodeGen
(
CodeGen
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
def
__init__
(
self
,
meta_graph
,
max_memory
:
int
=
None
,
print_mem
:
bool
=
False
,
print_progress
:
bool
=
False
)
->
None
:
super
().
__init__
()
# find the chunk regions
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
,
print_progress
)
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
if
print_progress
:
get_logger
().
info
(
"AutoChunk start codegen"
)
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
free_vars
:
List
[
str
]
=
[]
...
...
colossalai/autochunk/estimate_memory.py
View file @
c04f1832
...
...
@@ -43,6 +43,8 @@ class EstimateMemory(object):
delete_node
=
[]
if
user
.
op
not
in
(
"output"
,):
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
if
len
(
user
.
users
)
==
0
:
nodes_to_delete
.
append
(
user
)
if
to_keep
is
not
None
:
keep_list
=
[]
for
n
in
nodes_to_delete
:
...
...
@@ -135,6 +137,8 @@ class EstimateMemory(object):
if
user
.
op
in
(
"placeholder"
,
"output"
):
return
0
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
if
len
(
user
.
users
)
==
0
:
nodes_to_delete
.
append
(
user
)
delete_size
=
0
for
n
in
nodes_to_delete
:
if
n
.
name
in
chunk_inputs_names
:
...
...
@@ -294,3 +298,26 @@ class EstimateMemory(object):
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return
act_memory_peak_log
,
act_memory_after_node_log
,
active_node_list_log
def
get_active_nodes
(
self
,
node_list
:
List
)
->
List
:
"""
Get active nodes for every node
Args:
node_list (List): _description_
Returns:
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
active_node_list
=
[]
active_node_list_log
=
[]
user_to_last_uses
=
self
.
_get_last_usr
(
node_list
)
user_to_last_uses_no_free_var
=
self
.
_get_last_usr
(
node_list
)
delete_free_var_from_last_use
(
user_to_last_uses_no_free_var
)
for
_
,
node
in
enumerate
(
node_list
):
# log active node, only effective without chunk
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_remove_deactive_node
(
node
,
user_to_last_uses
,
active_node_list
)
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
return
active_node_list_log
colossalai/autochunk/search_chunk.py
View file @
c04f1832
...
...
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from
.select_chunk
import
SelectChunk
from
.trace_flow
import
TraceFlow
from
.trace_indice
import
TraceIndice
from
.utils
import
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
from
.utils
import
get_logger
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
class
SearchChunk
(
object
):
...
...
@@ -40,14 +40,14 @@ class SearchChunk(object):
print_mem (bool): print estimated memory
"""
def
__init__
(
self
,
gm
,
max_memory
=
None
,
print_mem
=
False
)
->
None
:
self
.
gm
=
gm
def
__init__
(
self
,
gm
,
max_memory
=
None
,
print_mem
=
False
,
print_progress
=
False
)
->
None
:
self
.
print_mem
=
print_mem
self
.
print_progress
=
print_progress
self
.
trace_indice
=
TraceIndice
(
list
(
gm
.
graph
.
nodes
))
self
.
trace_indice
.
trace_indice
()
self
.
estimate_memory
=
EstimateMemory
()
self
.
_init_trace
()
self
.
trace_flow
=
TraceFlow
(
self
.
trace_indice
)
self
.
reorder_graph
=
ReorderGraph
(
self
.
trace_indice
)
self
.
estimate_memory
=
EstimateMemory
()
self
.
select_chunk
=
SelectChunk
(
self
.
trace_indice
,
self
.
estimate_memory
,
...
...
@@ -55,7 +55,33 @@ class SearchChunk(object):
max_memory
=
max_memory
,
)
def
_find_peak_node
(
self
,
mem_peak
):
def
_init_trace
(
self
)
->
None
:
"""
find the max trace range for every node
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes
=
self
.
estimate_memory
.
get_active_nodes
(
self
.
trace_indice
.
node_list
)
cur_node_idx
=
len
(
self
.
_get_free_var_idx
())
max_chunk_region_list
=
[]
while
True
:
max_chunk_region
=
self
.
_search_max_chunk_region
(
active_nodes
,
cur_node_idx
)
cur_node_idx
=
max_chunk_region
[
1
]
if
cur_node_idx
==
len
(
active_nodes
)
-
1
:
break
max_chunk_region_list
.
append
(
max_chunk_region
)
# nothing to limit for the first range
max_chunk_region_list
=
max_chunk_region_list
[
1
:]
max_chunk_region_list
[
0
]
=
(
0
,
max_chunk_region_list
[
0
][
1
])
# set trace range and do the trace
if
self
.
print_progress
:
get_logger
().
info
(
"AutoChunk start tracing indice"
)
self
.
trace_indice
.
set_trace_range
(
max_chunk_region_list
,
active_nodes
)
self
.
trace_indice
.
trace_indice
()
def
_find_peak_node
(
self
,
mem_peak
:
List
)
->
int
:
max_value
=
max
(
mem_peak
)
max_idx
=
mem_peak
.
index
(
max_value
)
return
max_idx
...
...
@@ -73,7 +99,7 @@ class SearchChunk(object):
free_var_idx
.
append
(
idx
)
return
free_var_idx
def
_search_max_chunk_region
(
self
,
active_node
:
List
,
peak_node
:
Node
,
chunk_regions
:
List
)
->
Tuple
:
def
_search_max_chunk_region
(
self
,
active_node
:
List
,
peak_node
_idx
:
int
,
chunk_regions
:
List
=
None
)
->
Tuple
:
"""
Search max chunk region according to peak memory node
...
...
@@ -81,7 +107,7 @@ class SearchChunk(object):
Args:
active_node (List): active node status for every node
peak_node
(Node
): peak memory node
peak_node
_idx (int
): peak memory node
idx
chunk_regions (List): chunk region infos
Returns:
...
...
@@ -97,7 +123,7 @@ class SearchChunk(object):
# from peak_node to free_var
inside_flag
=
False
chunk_region_start
=
free_var_num
for
i
in
range
(
peak_node
,
-
1
,
-
1
):
for
i
in
range
(
peak_node
_idx
,
-
1
,
-
1
):
if
active_node_num
[
i
]
<=
threshold
:
inside_flag
=
True
if
inside_flag
and
active_node_num
[
i
]
>
threshold
:
...
...
@@ -107,21 +133,23 @@ class SearchChunk(object):
# from peak_node to len-2
inside_flag
=
False
chunk_region_end
=
len
(
active_node
)
-
1
for
i
in
range
(
peak_node
,
len
(
active_node
)):
for
i
in
range
(
peak_node
_idx
,
len
(
active_node
)):
if
active_node_num
[
i
]
<=
threshold
:
inside_flag
=
True
if
inside_flag
and
active_node_num
[
i
]
>
threshold
:
chunk_region_end
=
i
break
for
i
in
chunk_regions
:
region
=
i
[
"region"
]
if
chunk_region_start
>=
region
[
0
]
and
chunk_region_end
<=
region
[
1
]:
return
None
elif
(
region
[
0
]
<=
chunk_region_start
<=
region
[
1
]
and
chunk_region_end
>
region
[
1
]):
chunk_region_start
=
region
[
1
]
+
1
elif
(
region
[
0
]
<=
chunk_region_end
<=
region
[
1
]
and
chunk_region_start
<
region
[
0
]):
chunk_region_end
=
region
[
0
]
-
1
# avoid chunk regions overlap
if
chunk_regions
is
not
None
:
for
i
in
chunk_regions
:
region
=
i
[
"region"
]
if
chunk_region_start
>=
region
[
0
]
and
chunk_region_end
<=
region
[
1
]:
return
None
elif
(
region
[
0
]
<=
chunk_region_start
<=
region
[
1
]
and
chunk_region_end
>
region
[
1
]):
chunk_region_start
=
region
[
1
]
+
1
elif
(
region
[
0
]
<=
chunk_region_end
<=
region
[
1
]
and
chunk_region_start
<
region
[
0
]):
chunk_region_end
=
region
[
0
]
-
1
return
chunk_region_start
,
chunk_region_end
def
_find_chunk_info
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
)
->
List
:
...
...
@@ -154,6 +182,9 @@ class SearchChunk(object):
# dim size cannot be 1
if
(
get_node_shape
(
end_node
)[
end_dim
]
==
1
or
get_node_shape
(
start_node
)[
start_dim
]
==
1
):
continue
# must have users
if
len
(
end_node
.
users
)
==
0
:
continue
# check index source align
if
not
self
.
trace_flow
.
check_index_source
(
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
continue
...
...
@@ -253,6 +284,9 @@ class SearchChunk(object):
Returns:
chunk_infos (Dict)
"""
if
self
.
print_progress
:
get_logger
().
info
(
"AutoChunk start searching chunk regions"
)
chunk_infos
=
[]
(
init_mem_peak
,
...
...
@@ -272,6 +306,11 @@ class SearchChunk(object):
_
,
active_node
,
)
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
self
.
trace_indice
.
node_list
,
chunk_infos
)
if
self
.
print_progress
:
get_logger
().
info
(
"AutoChunk find chunk region %d = (%d, %d)"
%
(
len
(
chunk_infos
),
chunk_info
[
"region"
][
0
],
chunk_info
[
"region"
][
1
]))
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
if
self
.
print_mem
:
...
...
colossalai/autochunk/trace_flow.py
View file @
c04f1832
...
...
@@ -281,7 +281,10 @@ class TraceFlow(object):
if
chunk_dim
is
not
None
:
user_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
if
get_node_shape
(
input_node
)[
user_source
[
input_node_idx
][
0
]]
==
1
:
input_dict
[
user_idx
]
=
[
None
]
else
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
return
None
,
None
if
len
(
input_dict
)
==
0
:
...
...
colossalai/autochunk/trace_indice.py
View file @
c04f1832
...
...
@@ -33,6 +33,8 @@ class TraceIndice(object):
self
.
indice_trace_list
=
self
.
_init_indice_trace_list
()
self
.
indice_view_list
=
{}
self
.
indice_count
=
-
1
self
.
trace_range
=
[]
self
.
active_node_list
=
[]
def
_init_indice_trace_list
(
self
):
indice_trace_list
=
[]
...
...
@@ -48,6 +50,10 @@ class TraceIndice(object):
indice_trace_list
.
append
(
cur_trace
)
return
indice_trace_list
def
set_trace_range
(
self
,
trace_range
:
List
,
active_node_list
:
List
)
->
None
:
self
.
trace_range
=
trace_range
self
.
active_node_list
=
active_node_list
def
_add_indice
(
self
):
"""
Update the count and return it. To record the idx number.
...
...
@@ -493,6 +499,9 @@ class TraceIndice(object):
new_dim_num
=
sum
([
1
if
str
(
i
)
==
"None"
else
0
for
i
in
node_args
])
for
_
in
range
(
new_dim_num
):
self
.
_del_dim
(
node_idx
,
0
)
delete_dim_num
=
sum
([
1
if
str
(
i
)
==
"0"
else
0
for
i
in
node_args
])
for
_
in
range
(
delete_dim_num
):
self
.
_add_dim
(
node_idx
,
0
)
self
.
_assign_indice_as_input
(
node
,
node_idx
)
for
_
,
node_arg
in
enumerate
(
node_args
):
...
...
@@ -513,6 +522,9 @@ class TraceIndice(object):
elif
"None"
==
node_arg_str
:
self
.
_add_dim
(
node_idx
,
new_idx_count
)
new_idx_count
+=
1
elif
"0"
==
node_arg_str
:
self
.
_del_dim
(
node_idx
,
new_idx_count
)
origin_idx_count
+=
1
else
:
raise
NotImplementedError
()
...
...
@@ -596,6 +608,37 @@ class TraceIndice(object):
}
self
.
indice_view_list
[
node
]
=
view_dict
def
_clear_trace
(
self
,
node_idx
:
int
)
->
None
:
"""
clear too far trace to speed up computation
"""
trace_range
=
None
for
i
in
range
(
len
(
self
.
trace_range
)):
if
self
.
trace_range
[
i
][
1
]
==
node_idx
:
trace_range
=
(
self
.
trace_range
[
i
][
0
],
self
.
trace_range
[
i
][
1
])
break
if
self
.
trace_range
[
i
][
1
]
>
node_idx
:
break
if
trace_range
is
None
:
return
active_nodes
=
self
.
active_node_list
[
trace_range
[
0
]:
trace_range
[
1
]
+
1
]
active_nodes
=
set
(
flat_list
(
active_nodes
))
active_nodes
=
[
find_idx_by_name
(
i
,
self
.
node_list
)
for
i
in
active_nodes
]
for
i
in
range
(
trace_range
[
0
],
trace_range
[
1
]
+
1
):
trace
=
self
.
indice_trace_list
[
i
]
# clear compute
for
dim_compute
in
trace
[
"compute"
]:
for
i
in
range
(
len
(
dim_compute
)
-
1
,
-
1
,
-
1
):
if
dim_compute
[
i
]
<
trace_range
[
0
]
and
dim_compute
[
i
]
not
in
active_nodes
:
dim_compute
.
pop
(
i
)
continue
# clear source
for
dim_source
in
trace
[
"source"
]:
for
k
in
list
(
dim_source
.
keys
()):
if
k
<
trace_range
[
0
]
and
k
not
in
active_nodes
:
dim_source
.
pop
(
k
)
def
trace_indice
(
self
):
for
idx
,
node
in
enumerate
(
self
.
node_list
):
if
node
.
op
==
"placeholder"
:
...
...
@@ -655,3 +698,6 @@ class TraceIndice(object):
continue
else
:
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
# limit trace range
self
.
_clear_trace
(
idx
)
colossalai/autochunk/utils.py
View file @
c04f1832
...
...
@@ -2,6 +2,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from
torch.fx.node
import
Node
from
colossalai.logging
import
get_dist_logger
logger
=
get_dist_logger
()
def
get_logger
():
return
logger
def
flat_list
(
inputs
:
Any
)
->
List
:
"""
...
...
tests/test_autochunk/test_evoformer_stack_codegen.py
0 → 100644
View file @
c04f1832
from
functools
import
partial
import
pytest
import
torch
import
torch.fx
import
torch.multiprocessing
as
mp
try
:
from
fastfold.model.nn.evoformer
import
EvoformerStack
HAS_REPO
=
True
except
:
HAS_REPO
=
False
import
colossalai
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx._compatibility
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
CODEGEN_AVAILABLE
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.utils
import
free_port
if
CODEGEN_AVAILABLE
and
is_compatible_with_meta
():
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.tracer.experimental
import
ColoTracer
,
symbolic_trace
def
_test_fwd
(
model
:
torch
.
nn
.
Module
,
gm
:
ColoGraphModule
,
node
,
pair
,
node_mask
,
pair_mask
):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1, None)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model
=
model
.
cuda
()
with
torch
.
no_grad
():
non_fx_out
=
model
(
node
,
pair
,
node_mask
,
pair_mask
,
None
)
fx_out
=
gm
(
node
,
pair
,
node_mask
,
pair_mask
,
None
)
assert
torch
.
allclose
(
non_fx_out
[
0
],
fx_out
[
0
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
0
]
-
fx_out
[
0
]))
assert
torch
.
allclose
(
non_fx_out
[
1
],
fx_out
[
1
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
1
]
-
fx_out
[
1
]))
def
_build_openfold
():
model
=
EvoformerStack
(
c_m
=
256
,
c_z
=
128
,
c_hidden_msa_att
=
32
,
c_hidden_opm
=
32
,
c_hidden_mul
=
128
,
c_hidden_pair_att
=
32
,
c_s
=
384
,
no_heads_msa
=
8
,
no_heads_pair
=
4
,
no_blocks
=
2
,
# 48
transition_n
=
4
,
msa_dropout
=
0.15
,
pair_dropout
=
0.25
,
blocks_per_ckpt
=
None
,
inf
=
1000000000.0
,
eps
=
1e-08
,
clear_cache_between_blocks
=
False
,
is_multimer
=
False
,
).
eval
().
cuda
()
return
model
def
_test_evoformer_stack_codegen
(
rank
,
msa_len
,
pair_len
,
max_memory
):
# launch colossalai
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
1
,
host
=
"localhost"
,
port
=
free_port
(),
backend
=
"nccl"
,
)
# build model and input
model
=
_build_openfold
()
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
node_mask
=
torch
.
randn
(
1
,
msa_len
,
pair_len
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
pair_mask
=
torch
.
randn
(
1
,
pair_len
,
pair_len
).
cuda
()
# trace the meta graph and setup codegen
meta_graph
=
symbolic_trace
(
model
,
meta_args
=
{
"m"
:
node
.
to
(
torch
.
device
(
"meta"
)),
"z"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
"msa_mask"
:
node_mask
.
to
(
torch
.
device
(
"meta"
)),
"pair_mask"
:
pair_mask
.
to
(
torch
.
device
(
"meta"
)),
},
concrete_args
=
{
"chunk_size"
:
None
,
"_mask_trans"
:
True
,
},
)
interp
=
MetaInfoProp
(
meta_graph
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
),
MetaTensor
(
node_mask
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair_mask
,
fake_device
=
"cuda:0"
),
None
)
codegen
=
AutoChunkCodeGen
(
meta_graph
,
max_memory
=
max_memory
,
print_mem
=
False
,
print_progress
=
False
)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph
=
ColoTracer
().
trace
(
model
,
meta_args
=
{
"m"
:
node
.
to
(
torch
.
device
(
"meta"
)),
"z"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
"msa_mask"
:
node_mask
.
to
(
torch
.
device
(
"meta"
)),
"pair_mask"
:
pair_mask
.
to
(
torch
.
device
(
"meta"
)),
},
concrete_args
=
{
"chunk_size"
:
None
,
"_mask_trans"
:
True
,
},
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
,
ckpt_codegen
=
False
)
gm
.
recompile
()
# assert we have inserted chunk
code
=
graph
.
python_code
(
"self"
).
src
# print(code)
assert
"chunk_result = None; chunk_size = None;"
in
code
_test_fwd
(
model
,
gm
,
node
,
pair
,
node_mask
,
pair_mask
)
gpc
.
destroy
()
@
pytest
.
mark
.
skipif
(
not
(
CODEGEN_AVAILABLE
and
is_compatible_with_meta
()
and
HAS_REPO
),
reason
=
"torch version is lower than 1.12.0"
,
)
@
pytest
.
mark
.
parametrize
(
"max_memory"
,
[
None
,
24
,
28
,
32
])
@
pytest
.
mark
.
parametrize
(
"msa_len"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"pair_len"
,
[
64
])
def
test_evoformer_stack_codegen
(
msa_len
,
pair_len
,
max_memory
):
run_func
=
partial
(
_test_evoformer_stack_codegen
,
msa_len
=
msa_len
,
pair_len
=
pair_len
,
max_memory
=
max_memory
,
)
mp
.
spawn
(
run_func
,
nprocs
=
1
)
if
__name__
==
"__main__"
:
_test_evoformer_stack_codegen
(
0
,
32
,
64
,
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment