OpenDAS / ColossalAI / Commits

Commit 93f62dd1 (unverified)
Authored Jan 10, 2023 by Jiarui Fang; committed by GitHub on Jan 10, 2023

[autochunk] add autochunk feature

Parents: dddacd2d, 61fdd346
Changes: 27
Showing 20 changed files with 4,400 additions and 0 deletions (+4400, -0).
Changed files (all new, 0 deletions):

  colossalai/autochunk/autochunk_codegen.py             +593
  colossalai/autochunk/estimate_memory.py                +328
  colossalai/autochunk/reorder_graph.py                  +117
  colossalai/autochunk/search_chunk.py                   +319
  colossalai/autochunk/select_chunk.py                   +224
  colossalai/autochunk/trace_flow.py                     +420
  colossalai/autochunk/trace_indice.py                   +559
  colossalai/autochunk/utils.py                          +95
  tests/test_autochunk/benchmark_autochunk.py            +122
  tests/test_autochunk/evoformer/evoformer.py            +59
  tests/test_autochunk/evoformer/initializer.py          +29
  tests/test_autochunk/evoformer/kernel.py               +19
  tests/test_autochunk/evoformer/msa.py                  +95
  tests/test_autochunk/evoformer/ops.py                  +176
  tests/test_autochunk/evoformer/triangle.py             +192
  tests/test_autochunk/openfold/checkpointing.py         +84
  tests/test_autochunk/openfold/dropout.py               +78
  tests/test_autochunk/openfold/evoformer.py             +431
  tests/test_autochunk/openfold/msa.py                   +331
  tests/test_autochunk/openfold/outer_product_mean.py    +129
colossalai/autochunk/autochunk_codegen.py (new file, mode 100644): diff collapsed, not shown here.
colossalai/autochunk/estimate_memory.py (new file, mode 100644):

import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch
from torch.fx.node import Node, map_arg

from colossalai.fx.profiler import activation_size, parameter_size

from .utils import (
    delete_free_var_from_last_use,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node_except_placeholder,
)


class EstimateMemory(object):
    """
    Estimate memory with chunk
    """

    def __init__(self) -> None:
        pass

    def _get_meta_node_size(self, x):
        x = x.meta["tensor_meta"]
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        out_size = activation_size(n.meta["fwd_out"])
        out_node = [n.name] if out_size > 0 else []
        return out_size, out_node

    def _get_output_node_size(self, n):
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
        delete_size = 0
        delete_node = []
        if user.op not in ("output",):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if to_keep is not None:
                keep_list = []
                for n in nodes_to_delete:
                    if n.name in to_keep:
                        keep_list.append(n)
                for n in keep_list:
                    if n in nodes_to_delete:
                        nodes_to_delete.remove(n)
            if len(nodes_to_delete):
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum([i[0] for i in out_node])
                for i in range(len(out_node)):
                    if out_node[i][0] > 0:
                        delete_node.append(out_node[i][1][0])
                    elif nodes_to_delete[i].op == "placeholder":
                        delete_node.append(nodes_to_delete[i].name)
                    # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
                    #     delete_node.append(nodes_to_delete[i].name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
        return self._get_delete_node(user, user_to_last_uses, to_keep)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
        out_node = [self._get_output_node(i) for i in nodes_to_delete]
        delete_size = sum([i[0] for i in out_node])
        return delete_size

    def _get_last_usr(self, nodes):
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        mem = 0
        not_contiguous_ops = ["permute"]
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change the origin tensor, but creates a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:
                    # a module will just make the origin tensor contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
        if node not in chunk_node_dim:
            return 1.0
        node_shape = get_node_shape(node)
        chunk_dim = chunk_node_dim[node]["chunk_dim"]
        if chunk_dim is None:
            return 1.0
        else:
            return float(chunk_size) / node_shape[chunk_dim]

    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
        # if any(j in user.name for j in ['transpose', 'permute', 'view']):
        #     return 0
        if user.op in ("placeholder", "output"):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            if n.name in chunk_inputs_names:
                continue
            delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f\t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            if n.op in ["placeholder", "get_attr", "output"]:
                continue
            if any(i in n.name for i in ["getitem", "getattr"]):
                continue
            print("%s:%.2f\t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(
        self,
        node_list: List,
        chunk_infos=None,
        print_mem=False,
    ):
        """
        Estimate inference memory with chunk

        Args:
            node_list (List): list of graph nodes to estimate.
            chunk_infos (Dict): Chunk information. Defaults to None.
            print_mem (bool): Whether to print peak memory of every node. Defaults to False.

        Returns:
            act_memory_peak_log (List): peak memory of every node
            act_memory_after_node_log (List): memory after executing every node
            active_node_list_log (List): active nodes of every node. Active nodes refer to
                nodes generated but not yet deleted.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        user_to_last_uses = self._get_last_usr(node_list)
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1    # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]
            chunk_starts = [i[0] for i in chunk_regions]
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + \
                                 [j.name for i in chunk_inputs_non_chunk for j in i]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

        for idx, node in enumerate(node_list):
            # if node is in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:
                chunk_ratio = self._get_chunk_ratio(
                    node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
                )

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output
            elif node.op == "output":
                continue
            # no change for non compute node
            elif is_non_compute_node_except_placeholder(node):
                act_memory_peak_log.append(act_memory)
            # node is a compute op
            # calculate tmp, output node and delete node memory
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2)
                act_memory += self._get_output_node_size(node) * chunk_ratio / (1024**2)
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024**2)
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until the chunk ends
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node,
                        user_to_last_uses_no_free_var,
                        chunk_ratio,
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
                                                             chunk_inputs_names) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node is in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024**2)
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
                    node_list,
                    chunk_regions[chunk_region_idx][1],
                ) / (1024**2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_compute_op_mem_log(
            #     act_memory_after_node_log, node_list, "after"
            # )

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
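For orientation, a minimal sketch (not part of this commit) of how the estimator above can be driven on a traced graph, following the same MetaInfoProp setup used by the benchmark script later in this commit; the small input sizes here are made up.

import torch
import torch.fx

from colossalai.autochunk.estimate_memory import EstimateMemory
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from tests.test_autochunk.evoformer.evoformer import evoformer_base

model = evoformer_base()
node = torch.randn(1, 16, 32, 256)    # toy MSA tensor, sizes chosen only for illustration
pair = torch.randn(1, 32, 32, 128)    # toy pair tensor

# annotate every node with tensor_meta / fwd_out, which EstimateMemory reads
gm = torch.fx.symbolic_trace(model)
MetaInfoProp(gm).propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

peak_log, after_log, active_log = EstimateMemory().estimate_chunk_inference_mem(list(gm.graph.nodes))
print("estimated peak activation memory: %.2f MB" % max(peak_log))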
colossalai/autochunk/reorder_graph.py (new file, mode 100644):

from .trace_indice import TraceIndice
from .utils import find_idx_by_name


class ReorderGraph(object):
    """
    Reorder node list and indice trace list
    """

    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice
        self.all_reorder_map = {
            i: i for i in range(len(self.trace_indice.indice_trace_list))
        }

    def _get_reorder_map(self, chunk_info):
        reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}

        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
        chunk_prepose_nodes_idx = [
            find_idx_by_name(i.name, self.trace_indice.node_list) for i in chunk_prepose_nodes
        ]
        # put prepose nodes ahead
        for idx, n in enumerate(chunk_prepose_nodes):
            n_idx = chunk_prepose_nodes_idx[idx]
            reorder_map[n_idx] = chunk_region_start + idx
        # put other nodes after prepose nodes
        for n in self.trace_indice.node_list[chunk_region_start:chunk_region_end + 1]:
            if n in chunk_prepose_nodes:
                continue
            n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
            pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
            reorder_map[n_idx] = n_idx + pos

        return reorder_map

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        # update chunk info
        chunk_info["region"] = (
            chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
            chunk_info["region"][1],
        )
        new_inputs_dim = []
        for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
            new_input_dim = {}
            for k, v in input_dim.items():
                new_input_dim[reorder_map[k]] = v
            new_inputs_dim.append(new_input_dim)
        chunk_info["inputs_dim"] = new_inputs_dim
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        for origin_idx, map_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]

    def _reorder_self_node_list(self, reorder_map):
        new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
        self.trace_indice.node_list = new_node_list

    def _reorder_idx_trace(self, reorder_map):
        # reorder list
        new_idx_trace_list = [
            None for _ in range(len(self.trace_indice.indice_trace_list))
        ]
        for old_idx, new_idx in reorder_map.items():
            new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
        self.trace_indice.indice_trace_list = new_idx_trace_list
        # update compute
        for idx_trace in self.trace_indice.indice_trace_list:
            compute = idx_trace["compute"]
            for dim_compute in compute:
                for idx, i in enumerate(dim_compute):
                    dim_compute[idx] = reorder_map[i]
        # update source
        for idx_trace in self.trace_indice.indice_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                new_dim_source = {}
                for k, v in dim_source.items():
                    new_dim_source[reorder_map[k]] = v
                source[dim_idx] = new_dim_source

    def reorder_all(self, chunk_info):
        if chunk_info is None:
            return chunk_info
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in self.all_reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]
        return new_node_list

    def tmp_reorder(self, node_list, chunk_info):
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        # new tmp node list
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info
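A toy illustration (not from this commit) of what _get_reorder_map computes: prepose nodes are hoisted to the front of the chunk region and everything else slides back while keeping its relative order. The indices below are made up; the real code works on fx nodes.

# region spans node indices 2..5, and the node at index 4 is a prepose node
region_start, region_end = 2, 5
prepose_idx = [4]

reorder_map = {i: i for i in range(7)}
for rank, idx in enumerate(prepose_idx):          # prepose node 4 -> position 2
    reorder_map[idx] = region_start + rank
for idx in range(region_start, region_end + 1):   # other region nodes shift back
    if idx in prepose_idx:
        continue
    reorder_map[idx] = idx + sum(idx < p for p in prepose_idx)

print(reorder_map)    # {0: 0, 1: 1, 2: 3, 3: 4, 4: 2, 5: 5, 6: 6}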
colossalai/autochunk/search_chunk.py (new file, mode 100644):

import copy
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)


class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the AutoChunk strategy.
    Chunks are selected one by one until the search stops.

    The chunk search is as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for the current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_indice: trace the flow of every dim of every node to find all free dims
        trace_flow: determine the region chunk strategy
        reorder_graph: reorder nodes to improve chunk efficiency
        estimate_memory: estimate memory with chunk
        select_chunk: select the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
    """

    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
        self.gm = gm
        self.print_mem = print_mem
        self.trace_indice = TraceIndice(list(gm.graph.nodes))
        self.trace_indice.trace_indice()
        self.trace_flow = TraceFlow(self.trace_indice)
        self.reorder_graph = ReorderGraph(self.trace_indice)
        self.estimate_memory = EstimateMemory()
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
            max_memory=max_memory,
        )

    def _find_peak_node(self, mem_peak):
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var_idx(self) -> List:
        """
        Get free var index

        Returns:
            free_var_idx (List): all indexes of free vars
        """
        free_var_idx = []
        for idx, n in enumerate(self.trace_indice.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx

    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
        """
        Search max chunk region according to the peak memory node.

        The chunk region starts extending from the peak node and stops where the free var num is min.

        Args:
            active_node (List): active node status for every node
            peak_node (Node): peak memory node
            chunk_regions (List): chunk region infos

        Returns:
            chunk_region_start (int)
            chunk_region_end (int)
        """
        free_vars = self._get_free_var_idx()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # from peak_node to free_var
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # from peak_node to len-2
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                return None
            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                chunk_region_start = region[1] + 1
            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
        """
        Find chunk info for a region.

        We are given the region start and region end, and need to find out all chunk info for it.
        We first loop every dim of the start node and the end node, to see if we can find a dim pair
        which is linked in a flow and not computed.
        If found, we then search the flow in the whole region to find out all chunk infos.

        Args:
            input_trace (List): node's input trace in region
            output_trace (List): node's output trace in region
            start_idx (int): region start node index
            end_idx (int): region end node index

        Returns:
            chunk_infos: possible regions found
        """
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.trace_indice.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["indice"]):
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["indice"]):
                    # dim size cannot be 1
                    if (get_node_shape(end_node)[end_dim] == 1
                            or get_node_shape(start_node)[start_dim] == 1):
                        continue
                    # check index source align
                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                        continue
                    # check index compute
                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                        continue
                    # flow search
                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                    if chunk_info is None:
                        continue
                    # check index duplicate
                    if not self.trace_flow.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
        """
        Search every possible region within the max chunk region.

        Args:
            max_chunk_region (Tuple)
            peak_node (Node): peak memory node

        Returns:
            possible_chunk_region (List)
        """
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        input_trace = []    # trace of a node's input nodes
        for _, n in enumerate(self.trace_indice.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                    cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or \
                        is_non_compute_node(self.trace_indice.node_list[end_idx]):
                    continue
                # select free dim
                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(
        self,
        mem_peak: List[float],
        active_node: List[List[Node]],
        chunk_infos: List[Dict],
    ) -> Dict:
        """
        Find one chunk region.

        The chunk search is as follows:
        1. find the peak memory node
        2. find the max chunk region according to the peak memory node
        3. find all possible chunk regions in the max chunk region
        4. find the best chunk region for the current status

        Args:
            mem_peak (List): peak memory for every node
            active_node (List[List[Node]]): active node for every node
            chunk_infos (List[Dict]): all chunk info

        Returns:
            best_chunk_region (Dict)
        """
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
        if max_chunk_region is None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
        best_chunk_region = self.select_chunk._select_best_chunk_region(
            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False

    def search_region(self) -> Dict:
        """
        Search all chunk regions:
        1. Estimate current memory
        2. Find best chunk for current memory
        3. goto 1

        Returns:
            chunk_infos (Dict)
        """
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break

            chunk_infos.append(chunk_info)
            (
                mem_peak,
                _,
                active_node,
            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)

            if self._stop_search(init_mem_peak, mem_peak):
                break
        if self.print_mem:
            self.print_mem = False
            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
        return chunk_infos
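A minimal sketch (not part of this commit) of driving the search directly. It assumes gm is an FX GraphModule whose nodes already carry meta info (tensor_meta / fwd_out), prepared as in the benchmark script later in this commit; the memory budget is an example value.

from colossalai.autochunk.search_chunk import SearchChunk

# max_memory is in MB; passing None switches SelectChunk to its min-memory strategy
search = SearchChunk(gm, max_memory=1000, print_mem=True)
chunk_infos = search.search_region()
for info in chunk_infos:
    print(info["region"], info.get("chunk_size", 1))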
colossalai/autochunk/select_chunk.py (new file, mode 100644):

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import is_non_compute_node


class SelectChunk(object):

    def __init__(
        self,
        trace_indice: TraceIndice,
        estimate_memory: EstimateMemory,
        reorder_graph: ReorderGraph,
        max_memory=None,
    ):
        self.trace_indice = trace_indice
        self.estimate_memory = estimate_memory
        self.reorder_graph = reorder_graph
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory    # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
        if self.stratge == "min_memory":
            best_region = self._select_min_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        elif self.stratge == "fit_memory":
            best_region = self._select_fit_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        else:
            raise RuntimeError()
        return best_region

    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
                                        mem_peak):
        # stop chunking if the max memory already satisfies the memory limit
        if max(mem_peak) < self.max_memory:
            return None

        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)

        if len(possible_chunk_regions) == 0:
            return None

        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]:max_chunk_region[1] + 1]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            if cur_chunk_region_max_peak < self.max_memory:
                regions_dict.append({
                    "chunk_info": region,
                    "chunk_max_mem": cur_chunk_region_max_peak,
                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
                    "reorder_chunk_info": cur_region,
                    "reorder_node_list": cur_node_list,
                })
        # no region found
        if len(regions_dict) == 0:
            raise RuntimeError("Search failed. Try a larger memory threshold.")

        # select the min chunk len
        chunk_len = [i["chunk_len"] for i in regions_dict]
        best_region_idx = chunk_len.index(min(chunk_len))
        best_region = regions_dict[best_region_idx]

        # get max chunk size
        best_region = self._get_fit_chunk_size(best_region, chunk_infos)
        return best_region

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        chunk_size = 1
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # search a region
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos)[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
        # search exact size
        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
                                                                  chunk_infos)
        return chunk_info

    def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
        if left >= 16:
            gap = 4
        else:
            gap = 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        while right >= left + gap:
            mid = int((left + right) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            cur_chunk_infos = chunk_infos + [chunk_info]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos)[0]
            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
            if cur_chunk_max_mem >= self.max_memory:
                right = mid - gap
            else:
                left = mid + gap
        return left

    def _get_compute_node_num(self, start, end):
        count = 0
        for i in self.trace_indice.node_list[start:end + 1]:
            if not is_non_compute_node(i):
                count += 1
        return count

    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
                                        mem_peak):
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)

        if len(possible_chunk_regions) == 0:
            return None

        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]:max_chunk_region[1] + 1]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            regions_dict.append({
                "chunk_info": region,
                "chunk_max_mem": cur_chunk_region_max_peak,
                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
                "reorder_chunk_info": cur_region,
                "reorder_node_list": cur_node_list,
            })

        # select the min mem
        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
        best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
        best_region = regions_dict[best_region_idx]["chunk_info"]
        if best_region is not None:
            best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        for i in chunk_infos:
            region = i["region"]
            if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
                    or (chunk_region_start < region[0] and chunk_region_end < region[0])):
                return False
        return True
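The fit-memory strategy above first doubles the chunk size until the estimated peak exceeds the budget, then binary-searches between chunk_size // 2 and chunk_size with a coarse step. A toy rendition of that search with a made-up cost function (not from this commit); estimate_peak stands in for estimate_chunk_inference_mem.

def estimate_peak(chunk_size):
    # pretend peak memory (MB) grows linearly with the chunk size
    return 100 + 3 * chunk_size

max_memory = 400
size = 1
while estimate_peak(size) < max_memory:    # doubling phase: 1, 2, 4, ..., 128
    size *= 2

left, right = size // 2, size
gap = 4 if left >= 16 else 1
while right >= left + gap:                 # same update rule as _chunk_size_binary_search
    mid = int((left + right) / 2 + 0.5)
    if estimate_peak(mid) >= max_memory:
        right = mid - gap
    else:
        left = mid + gap
print(left)    # 100; the gap step makes this a coarse fit that may sit right at the budget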
colossalai/autochunk/trace_flow.py (new file, mode 100644):

from .trace_indice import TraceIndice
from .utils import (
    find_chunk_all_input_nodes,
    find_chunk_compute_input_and_output_nodes,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)


class TraceFlow(object):

    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice

    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
        """
        Check two given indices: one index should be the source of the other.

        Args:
            start_dim (int): start node chunk dim
            start_node (node): start node
            start_idx (int): start node index
            end_dim (int): end node chunk dim
            end_node (node): end node

        Returns:
            bool: True if the check passes
        """
        start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_trace_source = end_node_trace["source"][end_dim]
        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
        for node_idx, node_dim in sorted_source:
            if node_idx == start_node_idx and start_dim in node_dim:
                return True
            # it means we met a node outside the loop, and the node is not an input node
            if node_idx < start_idx:
                return False
        return False

    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
        """
        Check two given indices: check that they haven't been computed in the source trace.

        Args:
            start_idx (int): start node index
            end_dim (int): end node chunk dim
            end_node (node): end node
            end_idx (int): end node index

        Returns:
            bool: True if the check passes
        """
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_compute = end_node_trace["compute"][end_dim]
        if any(start_idx <= i <= end_idx for i in end_node_compute):
            return False
        return True

    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
        node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
        dim_source = node_from_source[node_from_dim]
        node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
        for k, v in dim_source.items():
            if k == node_to_idx:
                return v
        return None

    def _find_inherit_dim(self, input_node, input_dim, node):
        input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
        node_trace_source = self.trace_indice._find_source_trace_from_node(node)
        for node_dim in range(len(get_node_shape(node))):
            if (input_node_idx in node_trace_source[node_dim]
                    and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
                return node_dim
        return None

    def check_index_duplicate(self, chunk_infos, return_dim=False):
        input_dim_after_node = {}
        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
                inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
                if inherit_dim:
                    input_dim_after_node[k] = inherit_dim

        for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
            if is_non_compute_node_except_placeholder(node):
                continue
            count = 0
            duplicate_dims = []
            node_trace_source = self.trace_indice._find_source_trace_from_node(node)
            for node_dim in range(len(get_node_shape(node))):
                duplicate_dim = []
                duplicate_flag = False
                dim_source = node_trace_source[node_dim]
                for k, v in dim_source.items():
                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
                        if k in input_dim_after_node and input_dim_after_node[k] in v:
                            duplicate_flag = True
                            duplicate_dim.append((k, v))
                duplicate_dims.append(duplicate_dim)
                if duplicate_flag:
                    count += 1

            if count > 1:
                if return_dim:
                    return False, duplicate_dims
                else:
                    return False
        if return_dim:
            return True, None
        else:
            return True

    def _assgin_single_node_flow(
        self,
        arg_node,
        start_idx,
        end_idx,
        cur_node_dim,
        cur_node_compute,
        cur_node_source,
        cur_node_fix_dim,
        all_node_info,
        next_node_list,
    ):
        arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
        # arg must be in the chunk range or be an input
        if not (start_idx <= arg_idx < end_idx):
            return True

        # find arg dim
        if cur_node_dim is not None:
            # dim is computed
            if arg_idx in cur_node_compute[cur_node_dim]:
                return False
            if arg_idx not in cur_node_source[cur_node_dim]:
                arg_dim = None
            else:
                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
        else:
            arg_dim = None

        # get fix dim
        arg_fix_dim = []
        if cur_node_dim is not None:
            for i in cur_node_fix_dim:
                fix_dim_source = cur_node_source[i]
                if arg_idx in fix_dim_source:
                    arg_fix_dim.append(fix_dim_source[arg_idx][0])

        # if already in node_info, arg dim must be the same
        if arg_node in all_node_info:
            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                return False
            all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
        # else add it to the list
        else:
            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}

        next_node_list.append(arg_node)
        return True

    def _get_all_node_info(self, end_dim, start_idx, end_idx):
        cur_node_list = [self.trace_indice.node_list[end_idx]]    # start from the last node
        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

        while len(cur_node_list) > 0:
            next_node_list = []

            for cur_node in cur_node_list:
                # get cur node info
                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
                if cur_node_chunk_dim:
                    cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                    cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
                else:
                    cur_node_compute = cur_node_source = None

                # get all valid args
                arg_list = []
                for arg in cur_node.args:
                    if type(arg) != type(cur_node):
                        continue
                    if is_non_compute_node(arg):
                        continue
                    arg_list.append(arg)
                    flow_flag = self._assgin_single_node_flow(
                        arg,
                        start_idx,
                        end_idx,
                        cur_node_chunk_dim,
                        cur_node_compute,
                        cur_node_source,
                        cur_node_fix_dim,
                        all_node_info,
                        next_node_list,
                    )
                    if flow_flag == False:
                        return None

                if len(arg_list) == 2:
                    if any(i in cur_node.name for i in ["add", "mul"]):
                        for arg in arg_list:
                            if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
                                continue
                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
                            arg_fix_dim = all_node_info[arg]["fix_dim"]
                            arg_shape = get_node_shape(arg)
                            # add all dims as fix dims except the chunk dim
                            for i, shape in enumerate(arg_shape):
                                if shape != 1 and i != cur_node_chunk_dim:
                                    if i == arg_chunk_dim:
                                        return None
                                    if i not in arg_fix_dim:
                                        arg_fix_dim.append(i)
                    elif "einsum" in cur_node.name:
                        pass
                    elif "matmul" in cur_node.name:
                        pass
                    else:
                        raise NotImplementedError()
            cur_node_list = next_node_list
        return all_node_info

    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
        inputs_dim = []
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
            input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
            for user in input_node.users.keys():
                if is_non_compute_node(user):
                    continue
                user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]
                    if chunk_dim is not None:
                        user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
                        if input_node_idx in user_source:
                            input_dict[user_idx] = user_source[input_node_idx]
                        else:
                            return None, None
            if len(input_dict) == 0:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        for i in remove_inputs:
            if i in inputs:
                inputs.remove(i)
        return inputs, inputs_dim

    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
        # get all possible prepose nodes
        maybe_prepose_nodes = []
        for node, node_info in all_node_info.items():
            if node_info["chunk_dim"] is None:
                maybe_prepose_nodes.append(node)
        maybe_prepose_nodes.sort(
            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
            reverse=True,
        )    # from last node to first node
        prepose_nodes = []
        # set every node as root, search its args; if all are legal, turn the root and its args into prepose nodes
        while len(maybe_prepose_nodes) > 0:
            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
            tmp_cur_related_prepose_nodes = []
            prepose_flag = True

            # loop over the cur node's args until out of the chunk
            while len(tmp_cur_prepose_nodes) > 0:
                if prepose_flag == False:
                    break
                tmp_next_prepose_nodes = []
                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
                for cur_prepose_node in tmp_cur_prepose_nodes:
                    if prepose_flag == False:
                        break
                    for cur_prepose_node_arg in cur_prepose_node.args:
                        if type(cur_prepose_node_arg) != type(cur_prepose_node):
                            continue
                        # out of loop
                        if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name,
                                                              self.trace_indice.node_list) < end_idx):
                            continue
                        # compute op in loop
                        elif cur_prepose_node_arg in all_node_info:
                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                            else:
                                prepose_flag = False
                                break
                        # non compute op
                        else:
                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                tmp_cur_prepose_nodes = tmp_next_prepose_nodes

            if prepose_flag == False:
                maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
                continue
            else:
                for n in tmp_cur_related_prepose_nodes:
                    if n not in prepose_nodes:
                        prepose_nodes.append(n)
                    if n in maybe_prepose_nodes:
                        maybe_prepose_nodes.remove(n)
        # sort by index
        prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))

        return prepose_nodes

    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
        # we need to log input nodes to avoid deleting them in the loop
        chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
        # also need to get some prepose nodes' args out of non_chunk_inputs
        for n in chunk_info["args"]["prepose_nodes"]:
            chunk_node_list.remove(n)
        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
        for i in non_chunk_inputs:
            if i not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(i)
        return chunk_info

    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
        inputs, outputs = find_chunk_compute_input_and_output_nodes(
            self.trace_indice.node_list[start_idx:end_idx + 1])
        # only a single output is supported
        if len(outputs) > 1:
            return None

        # get every node's chunk dim and fix dim
        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
        if all_node_info is None:
            return None

        # get input nodes' chunk dim
        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
        if inputs is None:
            return None

        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": inputs,
            "inputs_non_chunk": [],
            "inputs_dim": inputs_dim,
            "outputs": outputs,
            "outputs_dim": end_dim,
            "node_chunk_dim": all_node_info,
            "args": {},
        }

        # move useless nodes ahead of the loop
        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)

        # find non chunk inputs
        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)

        # reassign reshape size, some sizes may have changed due to chunk
        chunk_info = self._reassgin_reshape_size(chunk_info)

        return chunk_info

    def _reassgin_reshape_size(self, chunk_info):
        chunk_region = chunk_info["region"]
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
            if any(i in node.name for i in ["reshape", "view"]):
                reshape_args = node.args[1:]
                reshape_log = self.trace_indice.indice_view_list[node]
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                reshape_size[node.name] = {}
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                    if reshape_arg_dim in reshape_log["dim_to"]:
                        continue
                    if reshape_arg_dim == chunk_dim:
                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
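For reference, the rough shape of the chunk_info dict that flow_search assembles above. The concrete values below are illustrative only; in the real dict the inputs/outputs entries are fx Node objects rather than strings, and node_chunk_dim maps nodes to their per-node chunk/fix dims.

example_chunk_info = {
    "region": (12, 47),                  # first / last node index of the chunked loop
    "inputs": ["node_a", "node_b"],      # chunked inputs (Node objects in practice)
    "inputs_non_chunk": ["bias"],        # inputs kept whole across iterations
    "inputs_dim": [{13: [0]}, {20: [1]}],
    "outputs": ["node_out"],
    "outputs_dim": 1,                    # which output dim is chunked
    "node_chunk_dim": {},                # per-node {"chunk_dim": ..., "fix_dim": [...]}
    "args": {"prepose_nodes": []},
    "reshape_size": {},
}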
colossalai/autochunk/trace_indice.py (new file, mode 100644): diff collapsed, not shown here.
colossalai/autochunk/utils.py (new file, mode 100644):

from typing import Any, Callable, Dict, Iterable, List, Tuple

from torch.fx.node import Node


def is_non_compute_node(node):
    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or \
            any(i in node.name for i in ["getitem", "getattr"]):
        return True
    return False


def get_node_shape(node):
    if hasattr(node.meta["tensor_meta"], "shape"):
        return node.meta["tensor_meta"].shape
    return None


def is_non_compute_node_except_placeholder(node):
    if any(i in node.op for i in ["get_attr", "output"]) or \
            any(i in node.name for i in ["getitem", "getattr"]):
        return True
    return False


def is_non_compute_node_except_placeholder_output(node):
    if any(i in node.op for i in ["get_attr"]) or \
            any(i in node.name for i in ["getitem", "getattr"]):
        return True
    return False


def find_idx_by_name(name, nodes_list):
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)


def delete_free_var_from_last_use(user_to_last_uses):
    for key, value in user_to_last_uses.items():
        for n in value:
            if n.op == "placeholder":
                user_to_last_uses[key].remove(n)


def find_chunk_all_input_nodes(nodes: List[Node]):
    """
    Find input nodes of a node list.
    Input nodes are nodes outside the list that are used by nodes in the list.
    """
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if input_node not in nodes and input_node not in input_nodes:
                input_nodes.append(input_node)
    return input_nodes


def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
    """
    Find compute input and output nodes of a node list.
    Input nodes are nodes outside the list that are used by nodes in the list;
    output nodes are nodes in the list whose results are used by nodes outside the list.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list,
    # we treat that input node as an input of the checkpoint function
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if (input_node not in nodes and input_node not in input_nodes
                    and not is_non_compute_node_except_placeholder(input_node)):
                input_nodes.append(input_node)

    # if a node has a user node which is not in the node list,
    # we treat that user node as the node receiving the current node's output
    for node in nodes:
        for output_node in node.users.keys():
            if (output_node not in nodes and node not in output_nodes
                    and not is_non_compute_node_except_placeholder_output(output_node)):
                output_nodes.append(node)

    return input_nodes, output_nodes
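A small sketch (not part of this commit) of the helpers above on a freshly traced function; node names follow torch.fx defaults.

import torch
import torch.fx

from colossalai.autochunk.utils import find_idx_by_name, is_non_compute_node

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)
nodes = list(gm.graph.nodes)

# placeholder and output nodes are non-compute; relu and add are compute ops
print([(n.name, is_non_compute_node(n)) for n in nodes])
print(find_idx_by_name("relu", nodes))    # index of the relu node in the node list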
tests/test_autochunk/benchmark_autochunk.py (new file, mode 100644):

import time

import torch
import torch.fx

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from tests.test_autochunk.evoformer.evoformer import evoformer_base
from tests.test_autochunk.openfold.evoformer import EvoformerBlock


def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024**2

    loop = 3
    with torch.no_grad():
        for _ in range(loop // 2 + 1):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        torch.cuda.synchronize()
        time1 = time.time()
        for _ in range(loop):
            if chunk_size:
                model(node, pair, chunk_size)
            else:
                model(node, pair)
        torch.cuda.synchronize()
        time2 = time.time()

    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))


def _build_autochunk(model, max_memory, node, pair):
    # trace the module and replace codegen
    graph = ColoTracer().trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )

    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # now run it twice to get meta info in graph module, not necessary
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # set code_gen
    codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # print
    # code = graph.python_code("self").src
    # print(code)
    return gm


def _build_openfold():
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).cuda()
    return model


def benchmark_evoformer():
    # init data and model
    msa_len = 256
    pair_len = 512
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    model = evoformer_base().cuda()

    # build autochunk model
    # max_memory = 1000  # MB, fit memory mode
    max_memory = None    # min memory mode
    autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)

    # build openfold
    chunk_size = 64
    openfold = _build_openfold()

    # benchmark
    _benchmark_evoformer(model, node, pair, "base")
    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
    _benchmark_evoformer(autochunk, node, pair, "autochunk")


if __name__ == "__main__":
    benchmark_evoformer()
tests/test_autochunk/evoformer/evoformer.py (new file, mode 100644):

import torch
import torch.nn as nn

from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack


def print_memory(init_mem, text=None):
    now_mem = torch.cuda.memory_allocated() / 1024**2 - init_mem
    max_mem = torch.cuda.max_memory_allocated() / 1024**2 - init_mem
    print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
    torch.cuda.reset_peak_memory_stats()


class EvoformerBlock(nn.Module):

    def __init__(self, d_node, d_pair):
        super(EvoformerBlock, self).__init__()
        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=d_pair)

    def forward(self, node, pair):
        node = self.msa_stack(node, pair)
        pair = pair + self.communication(node)
        pair = self.pair_stack(pair)
        return node, pair


class Evoformer(nn.Module):

    def __init__(self, d_node, d_pair):
        super(Evoformer, self).__init__()
        self.blocks = nn.ModuleList()
        for _ in range(1):
            self.blocks.append(EvoformerBlock(d_node, d_pair))

    def forward(self, node, pair):
        for b in self.blocks:
            node, pair = b(node, pair)
        return node, pair


def evoformer_tiny():
    return Evoformer(d_node=64, d_pair=32)


def evoformer_base():
    return Evoformer(d_node=256, d_pair=128)


def evoformer_large():
    return Evoformer(d_node=512, d_pair=256)


__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
tests/test_autochunk/evoformer/initializer.py (new file, mode 100755):

import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    Initialize tensors the same way as the xavier_initializer in PyTorch, but the dimensions are ordered differently:
    In PyTorch:
        [feature_out, feature_in, n_head, ...]
    In Jax:
        [..., n_head, feature_in, feature_out]
    However, the original AlphaFold2 code uses the Jax-style initializer to initialize tensors like:
        [feature_in, n_head, feature_out]
    In this function, we keep this behaviour to initialize [feature_in, n_head, ..., feature_out] tensors.
    """
    fan_in, fan_out = x.shape[-2:]
    if len(x.shape) > 2:
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std    # calculate uniform bounds from the standard deviation
    nn.init.uniform_(x, -dev, dev)

    return x
tests/test_autochunk/evoformer/kernel.py (new file, mode 100644):

import torch
import torch.nn.functional as F


def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, residual: torch.Tensor,
                     prob: float) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
    out = residual + out
    return out


def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, dropout_mask: torch.Tensor,
                              Z_raw: torch.Tensor, prob: float) -> torch.Tensor:
    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
tests/test_autochunk/evoformer/msa.py
0 → 100644
View file @
93f62dd1
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()

        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)
        b = F.linear(Z, self.linear_b_weights)
        b = b.permute(0, 3, 1, 2)
        # b = rearrange(b, 'b q k h -> b h q k')

        M = self.attention(M, b)
        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node, gating=True)

    def forward(self, M_raw):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)
        M = self.attention(M)
        M = M.transpose(-2, -3)

        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, d_pair=d_pair, p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        node = self.MSARowAttentionWithPairBias(node, pair)
        node = self.MSAColumnAttention(node)
        node = self.MSATransition(node)

        return node
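A shape sketch for MSAStack with toy sizes (chosen only for illustration): node is [batch, n_seq, n_res, d_node], pair is [batch, n_res, n_res, d_pair], and the stack returns an updated node of the same shape:

import torch

msa_stack = MSAStack(d_node=256, d_pair=128)
node = torch.randn(1, 8, 16, 256)
pair = torch.randn(1, 16, 16, 128)
out = msa_stack(node, pair)
assert out.shape == node.shape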
tests/test_autochunk/evoformer/ops.py
0 → 100755
View file @
93f62dd1
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj, n_feat_out, initializer='zero', use_bias=True)

    def forward(self, M):
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # O = rearrange(O, 'b i j d e -> b i j (d e)')
        o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
        Z = self.o_linear(o)

        return Z


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer='zero', use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

        q = self.to_q(in_data)
        k = self.to_k(in_data)
        v = self.to_v(in_data)

        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
        #               [q, k, v])
        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(1)
        weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)

        return output
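A shape sketch for SelfAttention with toy sizes: in_data is [batch_size1, batch_size2, len, qkv_dim] and the optional nonbatched_bias is [batch_size1, n_head, len, len]:

import torch

attn = SelfAttention(qkv_dim=64, c=16, n_head=4, out_dim=64, gating=True)
x = torch.randn(2, 3, 5, 64)
bias = torch.randn(2, 4, 5, 5)
y = attn(x, nonbatched_bias=bias)
assert y.shape == (2, 3, 5, 64)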
tests/test_autochunk/evoformer/triangle.py
0 → 100644
View file @
93f62dd1
import math

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)

        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)

        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        pair = self.TriangleMultiplicationOutgoing(pair)
        pair = self.TriangleMultiplicationIncoming(pair)
        pair = self.TriangleAttentionStartingNode(pair)
        pair = self.TriangleAttentionEndingNode(pair)
        pair = self.PairTransition(pair)
        return pair
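A shape sketch for PairStack with toy sizes: the pair activation is [batch, n_res, n_res, d_pair] and every sub-module updates it residually, so the shape is preserved:

import torch

pair_stack = PairStack(d_pair=128)
pair = torch.randn(1, 32, 32, 128)
out = pair_stack(pair)
assert out.shape == pair.shape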
tests/test_autochunk/openfold/checkpointing.py
0 → 100644
View file @
93f62dd1
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional

BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]


def get_checkpoint_fn():
    checkpoint = torch.utils.checkpoint.checkpoint
    return checkpoint


@torch.jit.ignore
def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """
    Chunk a list of blocks and run each chunk with activation
    checkpointing. We define a "block" as a callable whose only inputs are
    the outputs of the previous block.
    Implements Subsection 1.11.8
    Args:
        blocks:
            List of blocks
        args:
            Tuple of arguments for the first block.
        blocks_per_ckpt:
            Size of each chunk. A higher value corresponds to fewer
            checkpoints, and trades memory for speed. If None, no checkpointing
            is performed.
    Returns:
        The output of the final block
    """

    def wrap(a):
        return (a,) if type(a) is not tuple else a

    def exec(b, a):
        for block in b:
            a = wrap(block(*a))
        return a

    def chunker(s, e):

        def exec_sliced(*a):
            return exec(blocks[s:e], a)

        return exec_sliced

    # Avoids mishaps when the blocks take just one argument
    args = wrap(args)

    if blocks_per_ckpt is None:
        return exec(blocks, args)
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

    checkpoint = get_checkpoint_fn()

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        args = checkpoint(chunker(s, e), *args)
        args = wrap(args)

    return args
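A minimal sketch of checkpoint_blocks on plain nn.Linear blocks (not taken from OpenFold itself): each block consumes the previous block's outputs, and blocks_per_ckpt=2 wraps every two blocks in one torch.utils.checkpoint segment:

import torch
import torch.nn as nn

layers = [nn.Linear(16, 16) for _ in range(4)]
blocks = [lambda x, layer=layer: layer(x).relu() for layer in layers]
x = torch.randn(8, 16, requires_grad=True)
(out,) = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=2)
out.sum().backward()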
tests/test_autochunk/openfold/dropout.py
0 → 100644
View file @
93f62dd1
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.
    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        x *= mask
        return x


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
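A sketch of the shared-mask behaviour: DropoutRowwise fixes batch_dim=-3, so a single mask of shape [*, 1, N_res, C] is broadcast across all rows and every row is zeroed at the same positions:

import torch

drop = DropoutRowwise(0.5)
drop.train()
x = torch.ones(2, 4, 8, 16)
y = drop(x)
assert torch.equal(y[:, 0], y[:, 1])    # all rows share the same dropout pattern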
tests/test_autochunk/openfold/evoformer.py
0 → 100644
View file @
93f62dd1
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial

from .primitives import Linear, LayerNorm
from .dropout import DropoutRowwise, DropoutColumnwise
from .msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from .outer_product_mean import OuterProductMean
from .pair_transition import PairTransition
from .triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from .triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from .tensor_utils import chunk_layer


class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.
    Implements Algorithm 9
    """

    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res, C_m] MSA mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._transition(m, mask)

        return m


class EvoformerBlockCore(nn.Module):

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        is_multimer: bool = False,
    ):
        super(EvoformerBlockCore, self).__init__()
        self.is_multimer = is_multimer
        self.msa_transition = MSATransition(
            c_m=c_m,
            n=transition_n,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
        )

        self.tri_att_start = TriangleAttentionStartingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )

        self.pair_transition = PairTransition(
            c_z,
            transition_n,
        )

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        m = m + self.msa_transition(m, chunk_size=chunk_size)
        z = z + self.outer_product_mean(m, chunk_size=chunk_size)
        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
        z = z + self.ps_dropout_row_layer(self.tri_att_start(z, chunk_size=chunk_size))
        z = z + self.ps_dropout_col_layer(self.tri_att_end(z, chunk_size=chunk_size))
        z = z + self.pair_transition(z, chunk_size=chunk_size)

        return m, z


class EvoformerBlock(nn.Module):

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        is_multimer: bool,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m = m + self.msa_dropout_layer(self.msa_att_row(m, z=z, chunk_size=chunk_size))
        m = m + self.msa_att_col(m, chunk_size=chunk_size)
        m, z = self.core(
            m,
            z,
            chunk_size=chunk_size,
        )

        return m, z


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.
    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            ) for b in self.blocks
        ]

        if (self.clear_cache_between_blocks):

            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[..., 0, :, :])

        return m, z, s
tests/test_autochunk/openfold/msa.py
0 → 100644
View file @
93f62dd1
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn
from typing import Optional, List, Tuple

from .primitives import (
    Linear,
    LayerNorm,
    Attention,
    GlobalAttention,
    _attention_chunked_trainable,
)
from .checkpointing import get_checkpoint_fn
from .tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class MSAAttention(nn.Module):

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(self.c_z, self.no_heads, bias=False, init="normal")

        self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
            mask = m.new_ones(
                m.shape[:-3] + (n_seq, n_res),
            )

        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
        #bias = bias.expand(
        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
        #)

        if (self.pair_bias and
                z is not None and                    # For the
                self.layer_norm_z is not None and    # benefit of
                self.linear_z is not None            # TorchScript
           ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    def forward(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        m, mask_bias, z = self._prep_inputs(m, z, mask)

        biases = [mask_bias]
        if (z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(q_x=m, kv_x=m, biases=biases)

        return m


class MSARowAttentionWithPairBias(MSAAttention):
    """
    Implements Algorithm 7.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSARowAttentionWithPairBias, self).__init__(
            c_m,
            c_hidden,
            no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )


class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.
    By rights, this should also be a subclass of MSAAttention. Alas,
    most inheritance isn't supported by TorchScript.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
                m: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)

        m = self._msa_att(m, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m


class MSAColumnGlobalAttention(nn.Module):

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        inf=1e9,
        eps=1e-10,
    ):
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        mha_input = {
            "m": m,
        }
        return chunk_layer(
            self.global_attention,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        n_seq, n_res, c_in = m.shape[-3:]

        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)

        # [*, N_res, N_seq, C_in]
        m = self.layer_norm_m(m)

        if chunk_size is not None:
            m = self._chunk(m, chunk_size)
        else:
            m = self.global_attention(m=m)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
tests/test_autochunk/openfold/outer_product_mean.py
0 → 100644
View file @
93f62dd1
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional

import torch
import torch.nn as nn

from .primitives import Linear
from .tensor_utils import chunk_layer


class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(OuterProductMean, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden**2, c_z, init="final")

    def _opm(self, a, b):
        # [*, N_res, N_res, C, C]
        outer = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res, N_res, C * C]
        outer = outer.reshape(outer.shape[:-2] + (-1,))

        # [*, N_res, N_res, C_z]
        outer = self.linear_out(outer)

        return outer

    @torch.jit.ignore
    def _chunk(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
        # Since the "batch dim" in this case is not a true batch dimension
        # (in that the shape of the output depends on it), we need to
        # iterate over it ourselves
        a_reshape = a.reshape((-1,) + a.shape[-3:])
        b_reshape = b.reshape((-1,) + b.shape[-3:])
        out = []
        for a_prime, b_prime in zip(a_reshape, b_reshape):
            outer = chunk_layer(
                partial(self._opm, b=b_prime),
                {"a": a_prime},
                chunk_size=chunk_size,
                no_batch_dims=1,
            )
            out.append(outer)
        outer = torch.stack(out, dim=0)
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer

    def forward(self,
                m: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
        m = self.layer_norm(m)

        # [*, N_seq, N_res, C]
        mask = mask.unsqueeze(-1)
        a = self.linear_1(m) * mask
        b = self.linear_2(m) * mask

        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        if chunk_size is not None:
            outer = self._chunk(a, b, chunk_size)
        else:
            outer = self._opm(a, b)

        # [*, N_res, N_res, 1]
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)

        # [*, N_res, N_res, C_z]
        outer = outer / (self.eps + norm)

        return outer
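A shape sketch of the outer-product einsum used in _opm (toy sizes): with a and b of shape [N_res, N_seq, C] (already transposed), the contraction sums per-sequence outer products over N_seq; dividing by the mask norm in forward then turns that sum into a mean:

import torch

n_seq, n_res, c = 5, 7, 3
a = torch.randn(n_res, n_seq, c)
b = torch.randn(n_res, n_seq, c)
outer = torch.einsum("...bac,...dae->...bdce", a, b)
assert outer.shape == (n_res, n_res, c, c)
# equivalent explicit sum over sequences
ref = sum(torch.einsum("bc,de->bdce", a[:, s], b[:, s]) for s in range(n_seq))
assert torch.allclose(outer, ref, atol=1e-5)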