Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
27ab5240
Commit
27ab5240
authored
Jan 06, 2023
by
oahzxl
Browse files
refactor structure
parent
71e72c48
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
29 additions
and
34 deletions
+29
-34
autochunk/chunk_codegen.py
autochunk/chunk_codegen.py
+18
-23
autochunk/evoformer/evoformer.py
autochunk/evoformer/evoformer.py
+0
-0
autochunk/evoformer/initializer.py
autochunk/evoformer/initializer.py
+0
-0
autochunk/evoformer/kernel.py
autochunk/evoformer/kernel.py
+0
-0
autochunk/evoformer/msa.py
autochunk/evoformer/msa.py
+0
-0
autochunk/evoformer/ops.py
autochunk/evoformer/ops.py
+0
-0
autochunk/evoformer/triangle.py
autochunk/evoformer/triangle.py
+0
-0
autochunk/openfold/checkpointing.py
autochunk/openfold/checkpointing.py
+0
-0
autochunk/openfold/dropout.py
autochunk/openfold/dropout.py
+0
-0
autochunk/openfold/evoformer.py
autochunk/openfold/evoformer.py
+0
-0
autochunk/openfold/msa.py
autochunk/openfold/msa.py
+0
-0
autochunk/openfold/outer_product_mean.py
autochunk/openfold/outer_product_mean.py
+0
-0
autochunk/openfold/pair_transition.py
autochunk/openfold/pair_transition.py
+0
-0
autochunk/openfold/primitives.py
autochunk/openfold/primitives.py
+0
-0
autochunk/openfold/tensor_utils.py
autochunk/openfold/tensor_utils.py
+0
-0
autochunk/openfold/triangular_attention.py
autochunk/openfold/triangular_attention.py
+0
-0
autochunk/openfold/triangular_multiplicative_update.py
autochunk/openfold/triangular_multiplicative_update.py
+0
-0
autochunk_benchmark.py
autochunk_benchmark.py
+9
-9
autochunk_test.py
autochunk_test.py
+2
-2
No files found.
chunk_codegen.py
→
autochunk/
chunk_codegen.py
View file @
27ab5240
...
@@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
...
@@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
def
emit_code_with_chunk
(
def
emit_code_with_chunk
(
body
,
body
,
ckpt_func
,
nodes
,
nodes
,
emit_node_func
,
emit_node_func
,
delete_unused_value_func
,
delete_unused_value_func
,
meta_nodes
,
chunk_region_search
,
meta_graph
,
chunk_infos
max_memory
=
None
,
):
):
"""Emit code with nested activation checkpoint
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
When we detect some of the node.activation_checkpoint is a List, we will use
...
@@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
...
@@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
"""
"""
node_list
=
list
(
nodes
)
node_list
=
list
(
nodes
)
# find the chunk regions
chunk_regions
=
[
i
[
"region"
]
for
i
in
chunk_infos
]
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
,
max_memory
)
chunk_search
=
chunk_region_search
.
search_region
()
chunk_regions
=
[
i
[
"region"
]
for
i
in
chunk_search
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_regions
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_regions
]
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_regions
]
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_regions
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_
search
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_
infos
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_
search
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_
infos
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_
search
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_
infos
]
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_
search
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_
infos
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_
search
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_
infos
]
node_list
=
chunk_region_search
.
index_tracer
.
reorder_node_list
(
node_list
)
node_list
=
chunk_region_search
.
index_tracer
.
reorder_node_list
(
node_list
)
node_idx
=
0
node_idx
=
0
...
@@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
...
@@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
chunk_inputs
[
region_idx
],
chunk_inputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_
search
[
region_idx
][
"chunk_size"
],
chunk_
infos
[
region_idx
][
"chunk_size"
],
)
)
)
)
...
@@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
...
@@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
# ones like
# ones like
if
"ones_like"
in
node
.
name
:
if
"ones_like"
in
node
.
name
:
meta_node
=
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]
meta_node
=
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]
chunk_dim
=
chunk_
search
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
chunk_dim
=
chunk_
infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
"chunk_dim"
]
]
if
_get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
if
_get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
if
(
if
(
source_node
not
in
chunk_
search
[
region_idx
][
"node_chunk_dim"
]
source_node
not
in
chunk_
infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_
search
[
region_idx
][
"node_chunk_dim"
][
source_node
][
or
chunk_
infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
"chunk_dim"
]
]
is
None
is
None
...
@@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
...
@@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
)
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_
search
[
region_idx
][
"reshape_size"
]
body
[
-
1
],
node
.
name
,
chunk_
infos
[
region_idx
][
"reshape_size"
]
)
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
...
@@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
...
@@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
self
.
meta_graph
=
meta_graph
self
.
meta_graph
=
meta_graph
self
.
max_memory
=
max_memory
self
.
max_memory
=
max_memory
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
# find the chunk regions
self
.
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
,
max_memory
)
self
.
chunk_infos
=
self
.
chunk_region_search
.
search_region
()
def
_gen_python_code
(
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
...
@@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
...
@@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
# will use nested type of activation checkpoint codegen
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
emit_code_with_chunk
(
body
,
body
,
ckpt_func
,
nodes
,
nodes
,
emit_node
,
emit_node
,
delete_unused_values
,
delete_unused_values
,
self
.
meta_node
,
self
.
chunk_region_search
,
self
.
meta_graph
,
self
.
chunk_infos
self
.
max_memory
,
)
)
if
len
(
body
)
==
0
:
if
len
(
body
)
==
0
:
...
...
evoformer/evoformer.py
→
autochunk/
evoformer/evoformer.py
View file @
27ab5240
File moved
evoformer/initializer.py
→
autochunk/
evoformer/initializer.py
View file @
27ab5240
File moved
evoformer/kernel.py
→
autochunk/
evoformer/kernel.py
View file @
27ab5240
File moved
evoformer/msa.py
→
autochunk/
evoformer/msa.py
View file @
27ab5240
File moved
evoformer/ops.py
→
autochunk/
evoformer/ops.py
View file @
27ab5240
File moved
evoformer/triangle.py
→
autochunk/
evoformer/triangle.py
View file @
27ab5240
File moved
openfold/checkpointing.py
→
autochunk/
openfold/checkpointing.py
View file @
27ab5240
File moved
openfold/dropout.py
→
autochunk/
openfold/dropout.py
View file @
27ab5240
File moved
openfold/evoformer.py
→
autochunk/
openfold/evoformer.py
View file @
27ab5240
File moved
openfold/msa.py
→
autochunk/
openfold/msa.py
View file @
27ab5240
File moved
openfold/outer_product_mean.py
→
autochunk/
openfold/outer_product_mean.py
View file @
27ab5240
File moved
openfold/pair_transition.py
→
autochunk/
openfold/pair_transition.py
View file @
27ab5240
File moved
openfold/primitives.py
→
autochunk/
openfold/primitives.py
View file @
27ab5240
File moved
openfold/tensor_utils.py
→
autochunk/
openfold/tensor_utils.py
View file @
27ab5240
File moved
openfold/triangular_attention.py
→
autochunk/
openfold/triangular_attention.py
View file @
27ab5240
File moved
openfold/triangular_multiplicative_update.py
→
autochunk/
openfold/triangular_multiplicative_update.py
View file @
27ab5240
File moved
autochunk_benchmark.py
View file @
27ab5240
...
@@ -3,13 +3,13 @@ import time
...
@@ -3,13 +3,13 @@ import time
import
torch
import
torch
import
torch.fx
import
torch.fx
from
chunk_codegen
import
ChunkCodeGen
from
autochunk.
chunk_codegen
import
ChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.profiler
import
MetaTensor
from
evoformer.evoformer
import
evoformer_base
from
autochunk.
evoformer.evoformer
import
evoformer_base
from
openfold.evoformer
import
EvoformerBlock
from
autochunk.
openfold.evoformer
import
EvoformerBlock
def
_benchmark_evoformer
(
model
:
torch
.
nn
.
Module
,
node
,
pair
,
title
,
chunk_size
=
None
):
def
_benchmark_evoformer
(
model
:
torch
.
nn
.
Module
,
node
,
pair
,
title
,
chunk_size
=
None
):
...
@@ -94,23 +94,23 @@ def _build_openfold():
...
@@ -94,23 +94,23 @@ def _build_openfold():
def
benchmark_evoformer
():
def
benchmark_evoformer
():
# init data and model
# init data and model
msa_len
=
256
msa_len
=
256
pair_len
=
2048
pair_len
=
1024
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
model
=
evoformer_base
().
cuda
()
model
=
evoformer_base
().
cuda
()
# build autochunk model
# build autochunk model
max_memory
=
10000
# MB fit memory mode
#
max_memory = 10000 # MB fit memory mode
#
max_memory = None # min memory mode
max_memory
=
None
# min memory mode
autochunk
=
_build_autochunk
(
evoformer_base
().
cuda
(),
max_memory
,
node
,
pair
)
autochunk
=
_build_autochunk
(
evoformer_base
().
cuda
(),
max_memory
,
node
,
pair
)
# build openfold
# build openfold
chunk_size
=
64
chunk_size
=
64
openfold
=
_build_openfold
()
#
openfold = _build_openfold()
# benchmark
# benchmark
_benchmark_evoformer
(
model
,
node
,
pair
,
"base"
)
#
_benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer
(
openfold
,
node
,
pair
,
"openfold"
,
chunk_size
=
chunk_size
)
#
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer
(
autochunk
,
node
,
pair
,
"autochunk"
)
_benchmark_evoformer
(
autochunk
,
node
,
pair
,
"autochunk"
)
...
...
chunk_
codegen_run
.py
→
auto
chunk_
test
.py
View file @
27ab5240
...
@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
...
@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
,
TensorMetadata
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
,
TensorMetadata
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.profiler
import
MetaTensor
from
evoformer.evoformer
import
evoformer_base
from
autochunk.
evoformer.evoformer
import
evoformer_base
from
chunk_codegen
import
ChunkCodeGen
from
autochunk.
chunk_codegen
import
ChunkCodeGen
with_codegen
=
True
with_codegen
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment