Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4f53624d
Commit
4f53624d
authored
Jul 08, 2022
by
Gustaf Ahdritz
Browse files
Add tracing, TensorFloat32 utilization, and FlashAttention
parent
db43f4ec
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
496 additions
and
56 deletions
+496
-56
openfold/config.py
openfold/config.py
+10
-1
openfold/model/evoformer.py
openfold/model/evoformer.py
+12
-2
openfold/model/model.py
openfold/model/model.py
+2
-0
openfold/model/primitives.py
openfold/model/primitives.py
+86
-9
openfold/utils/trace_utils.py
openfold/utils/trace_utils.py
+260
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+126
-44
No files found.
openfold/config.py
View file @
4f53624d
...
@@ -23,7 +23,11 @@ def enforce_config_constraints(config):
...
@@ -23,7 +23,11 @@ def enforce_config_constraints(config):
(
(
"model.template.average_templates"
,
"model.template.average_templates"
,
"model.template.offload_templates"
"model.template.offload_templates"
)
),
(
"globals.use_lma"
,
"globals.use_flash"
,
),
]
]
for
s1
,
s2
in
mutually_exclusive_bools
:
for
s1
,
s2
in
mutually_exclusive_bools
:
...
@@ -315,7 +319,12 @@ config = mlc.ConfigDict(
...
@@ -315,7 +319,12 @@ config = mlc.ConfigDict(
"globals"
:
{
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"chunk_size"
:
chunk_size
,
# Use Rabe & Staats' low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma"
:
False
,
"use_lma"
:
False
,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma.
"use_flash"
:
True
,
"offload_inference"
:
False
,
"offload_inference"
:
False
,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_m"
:
c_m
,
...
...
openfold/model/evoformer.py
View file @
4f53624d
...
@@ -361,6 +361,7 @@ class EvoformerBlock(nn.Module):
...
@@ -361,6 +361,7 @@ class EvoformerBlock(nn.Module):
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
...
@@ -390,12 +391,14 @@ class EvoformerBlock(nn.Module):
...
@@ -390,12 +391,14 @@ class EvoformerBlock(nn.Module):
),
),
inplace
=
inplace_safe
,
inplace
=
inplace_safe
,
)
)
m
=
add
(
m
,
m
=
add
(
m
,
self
.
msa_att_col
(
self
.
msa_att_col
(
m
,
m
,
mask
=
msa_mask
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
),
),
inplace
=
inplace_safe
,
inplace
=
inplace_safe
,
)
)
...
@@ -666,6 +669,7 @@ class EvoformerStack(nn.Module):
...
@@ -666,6 +669,7 @@ class EvoformerStack(nn.Module):
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
,
use_lma
:
bool
,
use_flash
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
inplace_safe
:
bool
,
...
@@ -678,6 +682,7 @@ class EvoformerStack(nn.Module):
...
@@ -678,6 +682,7 @@ class EvoformerStack(nn.Module):
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
inplace_safe
=
inplace_safe
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
)
...
@@ -756,6 +761,7 @@ class EvoformerStack(nn.Module):
...
@@ -756,6 +761,7 @@ class EvoformerStack(nn.Module):
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -773,6 +779,9 @@ class EvoformerStack(nn.Module):
...
@@ -773,6 +779,9 @@ class EvoformerStack(nn.Module):
Inference-time subbatch size. Acts as a minimum if
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_lma: Whether to use low-memory attention during inference
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
Returns:
Returns:
m:
m:
[*, N_seq, N_res, C_m] MSA embedding
[*, N_seq, N_res, C_m] MSA embedding
...
@@ -786,6 +795,7 @@ class EvoformerStack(nn.Module):
...
@@ -786,6 +795,7 @@ class EvoformerStack(nn.Module):
z
=
z
,
z
=
z
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
inplace_safe
=
inplace_safe
,
...
@@ -947,10 +957,10 @@ class ExtraMSAStack(nn.Module):
...
@@ -947,10 +957,10 @@ class ExtraMSAStack(nn.Module):
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
openfold/model/model.py
View file @
4f53624d
...
@@ -380,6 +380,7 @@ class AlphaFold(nn.Module):
...
@@ -380,6 +380,7 @@ class AlphaFold(nn.Module):
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
use_flash
=
self
.
globals
.
use_flash
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
...
@@ -392,6 +393,7 @@ class AlphaFold(nn.Module):
...
@@ -392,6 +393,7 @@ class AlphaFold(nn.Module):
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
use_flash
=
self
.
globals
.
use_flash
,
inplace_safe
=
inplace_safe
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
...
...
openfold/model/primitives.py
View file @
4f53624d
...
@@ -18,6 +18,9 @@ from typing import Optional, Callable, List, Tuple, Sequence
...
@@ -18,6 +18,9 @@ from typing import Optional, Callable, List, Tuple, Sequence
import
numpy
as
np
import
numpy
as
np
import
deepspeed
import
deepspeed
from
flash_attn.bert_padding
import
unpad_input
,
pad_input
from
flash_attn.flash_attention
import
FlashAttention
from
flash_attn.flash_attn_interface
import
flash_attn_unpadded_kvpacked_func
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
scipy.stats
import
truncnorm
from
scipy.stats
import
truncnorm
...
@@ -407,8 +410,10 @@ class Attention(nn.Module):
...
@@ -407,8 +410,10 @@ class Attention(nn.Module):
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
q_chunk_size
:
int
=
DEFAULT_LMA_Q_CHUNK_SIZE
,
lma_q_chunk_size
:
int
=
DEFAULT_LMA_Q_CHUNK_SIZE
,
kv_chunk_size
:
int
=
DEFAULT_LMA_KV_CHUNK_SIZE
,
lma_kv_chunk_size
:
int
=
DEFAULT_LMA_KV_CHUNK_SIZE
,
use_flash
:
bool
=
False
,
flash_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -427,25 +432,34 @@ class Attention(nn.Module):
...
@@ -427,25 +432,34 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
implementation is used instead
q_chunk_size:
lma_
q_chunk_size:
Query chunk size (for LMA)
Query chunk size (for LMA)
kv_chunk_size:
lma_
kv_chunk_size:
Key/Value chunk size (for LMA)
Key/Value chunk size (for LMA)
Returns
Returns
[*, Q, C_q] attention update
[*, Q, C_q] attention update
"""
"""
if
(
biases
is
None
):
biases
=
[]
if
(
use_lma
and
(
q_chunk_size
is
None
or
kv_chunk_size
is
None
)):
if
(
use_lma
and
(
q_chunk_size
is
None
or
kv_chunk_size
is
None
)):
raise
ValueError
(
raise
ValueError
(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
"be provided"
)
)
if
(
use_memory_efficient_kernel
and
use_lma
):
if
(
use_flash
and
biases
is
not
None
):
raise
ValueError
(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options
=
[
use_memory_efficient_kernel
,
use_lma
,
use_flash
]
if
(
sum
(
attn_options
)
>
1
):
raise
ValueError
(
raise
ValueError
(
"Choose
one of use_memory_efficient_kernel and use_lma
"
"Choose
at most one alternative attention algorithm
"
)
)
if
(
biases
is
None
):
biases
=
[]
# [*, H, Q/K, C_hidden]
# [*, H, Q/K, C_hidden]
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
)
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
)
...
@@ -463,8 +477,10 @@ class Attention(nn.Module):
...
@@ -463,8 +477,10 @@ class Attention(nn.Module):
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
for
b
in
biases
for
b
in
biases
]
]
o
=
_lma
(
q
,
k
,
v
,
biases
,
q_chunk_size
,
kv_chunk_size
)
o
=
_lma
(
q
,
k
,
v
,
biases
,
lma_
q_chunk_size
,
lma_
kv_chunk_size
)
o
=
o
.
transpose
(
-
2
,
-
3
)
o
=
o
.
transpose
(
-
2
,
-
3
)
elif
(
use_flash
):
o
=
_flash_attn
(
q
,
k
,
v
,
flash_mask
)
else
:
else
:
o
=
_attention
(
q
,
k
,
v
,
biases
)
o
=
_attention
(
q
,
k
,
v
,
biases
)
o
=
o
.
transpose
(
-
2
,
-
3
)
o
=
o
.
transpose
(
-
2
,
-
3
)
...
@@ -623,3 +639,64 @@ def _lma(
...
@@ -623,3 +639,64 @@ def _lma(
o
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
=
q_chunk_out
o
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
=
q_chunk_out
return
o
return
o
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
    """Compute attention with the FlashAttention kv-packed kernel.

    The inputs are cast to half precision, flattened over their leading
    batch dimensions, and handed to ``flash_attn_unpadded_kvpacked_func``.
    The result is cast back to the dtype ``q`` had on entry.

    Args:
        q: Query tensor whose last three dimensions are unpacked as
            (no_heads, n, c). Per the note at the kernel call below, it is
            expected to be scaled already (softmax_scale is fixed to 1).
        k: Key tensor; transposed and flattened the same way as ``q``.
        v: Value tensor; transposed and flattened the same way as ``q``.
        kv_mask: Mask passed to ``unpad_input`` to drop padded key/value
            positions. Must not be None: ``.half()`` is called on it
            unconditionally.

    Returns:
        Tensor of shape ``(*batch_dims, n, no_heads, c)`` in the original
        dtype of ``q``.
    """
    batch_dims = q.shape[:-3]
    no_heads, n, c = q.shape[-3:]
    # Remember the caller's dtype so the output can be cast back at the end
    dtype = q.dtype

    # Everything handed to the kernel is converted to fp16 first
    q = q.half()
    k = k.half()
    v = v.half()
    kv_mask = kv_mask.half()

    # [*, B, N, H, C]
    q = q.transpose(-2, -3)
    k = k.transpose(-2, -3)
    v = v.transpose(-2, -3)

    # [B_flat, N, H, C]
    q = q.reshape(-1, *q.shape[-3:])
    k = k.reshape(-1, *k.shape[-3:])
    v = v.reshape(-1, *v.shape[-3:])

    # Flattened batch size
    batch_size = q.shape[0]

    # [B_flat * N, H, C]
    q = q.reshape(-1, *q.shape[-2:])

    # Queries are not unpadded: every flattened batch entry contributes
    # exactly n query rows, so the cumulative lengths are 0, n, 2n, ...
    q_max_s = n
    q_cu_seqlens = torch.arange(
        0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
    )

    # [B_flat, N, 2, H, C]
    kv = torch.stack([k, v], dim=-3)
    kv_shape = kv.shape

    # [B_flat, N, 2 * H * C]
    kv = kv.reshape(*kv.shape[:-3], -1)

    # NOTE(review): kv_mask is passed through without being flattened like
    # q/k/v above; presumably it must already be [B_flat, N] --- confirm
    # against callers that have more than one batch dimension.
    kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
    # Restore the trailing (2, H, C) dimensions on the unpadded rows
    kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])

    out = flash_attn_unpadded_kvpacked_func(
        q,
        kv_unpad,
        q_cu_seqlens,
        kv_cu_seqlens,
        q_max_s,
        kv_max_s,
        dropout_p = 0.,
        softmax_scale = 1., # q has been scaled already
    )

    # [*, B, N, H, C]
    out = out.reshape(*batch_dims, n, no_heads, c)

    # Cast back from fp16 to the caller's dtype
    out = out.to(dtype=dtype)

    return out
openfold/utils/trace_utils.py
0 → 100644
View file @
4f53624d
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
from
functools
import
partialmethod
import
numpy
as
np
import
torch
from
openfold.utils.tensor_utils
import
tensor_tree_map
def pad_feature_dict_seq(feature_dict, seqlen):
    """Zero-pad the sequence axis of a feature dict up to ``seqlen``.

    Used for tracing. Only the features listed in the table below are
    padded; every other entry is passed through as the same object. The
    input dict itself is not modified.

    Args:
        feature_dict: Mapping from feature name to array. Must contain
            "aatype" and "seq_length".
        seqlen: Target sequence length. Must be at least the current
            length, which is read from ``feature_dict["aatype"].shape[-2]``.

    Returns:
        A new dict holding the padded features. Element 0 of its
        "seq_length" entry is overwritten with ``seqlen``.
    """
    # Axis holding the sequence dimension for each feature that gets padded
    residue_axis = {
        "aatype": -2,
        "between_segment_residues": -1,
        "residue_index": -1,
        "seq_length": -1,
        "deletion_matrix_int": -1,
        "msa": -1,
        "num_alignments": -1,
        "template_aatype": -2,
        "template_all_atom_mask": -2,
        "template_all_atom_positions": -3,
    }

    # The real sequence length can't be longer than the desired one
    assert feature_dict["aatype"].shape[-2] <= seqlen

    padded = {}
    for name, value in feature_dict.items():
        axis = residue_axis.get(name)
        if axis is None:
            # Not a per-residue feature we know about: pass through as-is
            padded[name] = value
        else:
            target_shape = list(value.shape)
            target_shape[axis] = seqlen
            buf = np.zeros(target_shape, dtype=value.dtype)
            # Copy the original data into the leading corner of the buffer
            corner = tuple(slice(0, extent) for extent in value.shape)
            buf[corner] = value
            padded[name] = buf

    # Only the first element is updated; the rest keep their padded values
    padded["seq_length"][0] = seqlen

    return padded
def trace_model_(model, sample_input):
    """Replace the model's Evoformer blocks with traced versions, in place.

    The procedure is:
      1. Run a shortened dry pass (one block per stack, one recycling
         iteration) so the chunk size tuners cache their chunk sizes.
      2. Trace every Evoformer block with ``torch.jit.trace`` on fake
         inputs of the right shape and swap the traced wrappers into
         ``model.evoformer.blocks``.
      3. Run a second dry pass over two recycling iterations.

    Tracing of the extra MSA stack and the template pair stack is currently
    commented out below.

    Args:
        model: The model to modify. Its ``evoformer``, ``extra_msa_stack``,
            ``template_pair_stack``, ``globals``, ``config`` and
            ``template_config`` attributes are read.
        sample_input: Processed feature dict whose tensors carry the
            recycling dimension last. Shapes are taken from the final
            recycling iteration.

    Returns:
        None. ``model`` is modified in place.
    """
    # Grab the inputs to the final recycling iteration
    feats = tensor_tree_map(lambda t: t[..., -1], sample_input)

    # Gather some metadata
    n = feats["aatype"].shape[-1]
    msa_depth = feats["true_msa"].shape[-2]
    extra_msa_depth = feats["extra_msa"].shape[-2]
    no_templates = feats["template_aatype"].shape[-2]
    device = feats["aatype"].device

    seq_mask = feats["seq_mask"].to(device)
    pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
    msa_mask = feats["msa_mask"].to(device)
    extra_msa_mask = feats["extra_msa_mask"].to(device)
    template_pair_mask = torch.stack([pair_mask] * no_templates, dim=-3)

    # Create some fake representations with the correct shapes
    # (t, a, extra_msa_mask and template_pair_mask are only consumed by the
    # commented-out tracing code further down)
    m = torch.rand(msa_depth, n, model.globals.c_m).to(device)
    z = torch.rand(n, n, model.globals.c_z).to(device)
    t = torch.rand(no_templates, n, n, model.globals.c_t).to(device)
    a = torch.rand(extra_msa_depth, n, model.globals.c_e).to(device)

    # We need to do a dry run through the model so the chunk size tuners'
    # trial runs (which run during the first-ever model iteration) aren't
    # baked into the trace. There's no need to run the entire thing,
    # though; we just need to run one block from each transformer stack.
    evoformer_blocks = model.evoformer.blocks
    extra_msa_blocks = model.extra_msa_stack.blocks
    template_pair_stack_blocks = None
    try:
        model.evoformer.blocks = evoformer_blocks[:1]
        model.extra_msa_stack.blocks = extra_msa_blocks[:1]

        if(model.template_config.enabled):
            template_pair_stack_blocks = model.template_pair_stack.blocks
            model.template_pair_stack.blocks = template_pair_stack_blocks[:1]

        single_recycling_iter_input = tensor_tree_map(
            lambda t: t[..., :1], sample_input,
        )

        with torch.no_grad():
            _ = model(single_recycling_iter_input)
    finally:
        # Always put the full stacks back, even if the dry run raised, so
        # a failure can't leave the model permanently truncated
        model.evoformer.blocks = evoformer_blocks
        model.extra_msa_stack.blocks = extra_msa_blocks
        if(template_pair_stack_blocks is not None):
            model.template_pair_stack.blocks = template_pair_stack_blocks

    del evoformer_blocks, extra_msa_blocks, template_pair_stack_blocks

    def get_tuned_chunk_size(module):
        """Return the chunk size cached by the module's chunk size tuner."""
        tuner = module.chunk_size_tuner
        chunk_size = tuner.cached_chunk_size

        # After our trial run above, this should always be set
        assert(chunk_size is not None)

        return chunk_size

    # Fetch the resulting chunk sizes
    evoformer_chunk_size = model.globals.chunk_size
    if(model.evoformer.chunk_size_tuner is not None):
        evoformer_chunk_size = get_tuned_chunk_size(model.evoformer)

    extra_msa_chunk_size = model.globals.chunk_size
    if(model.extra_msa_stack.chunk_size_tuner is not None):
        extra_msa_chunk_size = get_tuned_chunk_size(model.extra_msa_stack)

    if(model.template_config.enabled):
        template_pair_stack_chunk_size = model.globals.chunk_size
        if(model.template_pair_stack.chunk_size_tuner is not None):
            template_pair_stack_chunk_size = get_tuned_chunk_size(
                model.template_pair_stack
            )

    def trace_block(block, block_inputs):
        """Trace ``block`` and return a wrapper that tensorizes its inputs."""
        # Yes, yes, I know
        # (stderr is redirected to None while tracing --- presumably to
        # silence the tracer's warnings)
        with contextlib.redirect_stderr(None):
            traced_block = torch.jit.trace(block, block_inputs)

        traced_block = torch.jit.optimize_for_inference(traced_block)

        def to_tensor(value):
            # isinstance (rather than an exact type comparison) lets Tensor
            # subclasses through untouched instead of copying them
            if(isinstance(value, torch.Tensor)):
                return value
            return torch.tensor(value)

        # All trace inputs need to be tensors. This wrapper takes care of that
        def traced_block_wrapper(*args, **kwargs):
            args = [to_tensor(a) for a in args]
            kwargs = {k: to_tensor(v) for k, v in kwargs.items()}
            return traced_block(*args, **kwargs)

        return traced_block_wrapper

    def verify_arg_order(fn, arg_list):
        """ Because it's difficult to specify keyword arguments of Module
            functions during tracing, we need to pass them as a tuple. As a
            sanity check, we manually verify their order here.
        """
        fn_arg_names = fn.__code__.co_varnames

        # Remove the "self" parameter
        assert(fn_arg_names[0] == "self")
        fn_arg_names = fn_arg_names[1:]

        # Trim unspecified arguments
        fn_arg_names = fn_arg_names[:len(arg_list)]
        name_tups = zip(fn_arg_names, [n for n, _ in arg_list])
        assert(all([n1 == n2 for n1, n2 in name_tups]))

    evoformer_attn_chunk_size = max(
        model.globals.chunk_size, evoformer_chunk_size // 4
    )

    # Positional inputs for the block's forward(), paired with the parameter
    # names they are expected to line up with
    evoformer_arg_tuples = [
        ("m", m),
        ("z", z),
        ("msa_mask", msa_mask),
        ("pair_mask", pair_mask),
        ("chunk_size", torch.tensor(evoformer_chunk_size)),
        ("use_lma", torch.tensor(model.globals.use_lma)),
        ("use_flash", torch.tensor(model.globals.use_flash)),
        ("inplace_safe", torch.tensor(1)),
        ("_mask_trans", torch.tensor(model.config._mask_trans)),
        ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
    ]
    verify_arg_order(
        model.evoformer.blocks[0].forward, evoformer_arg_tuples
    )
    evoformer_args = [arg for _, arg in evoformer_arg_tuples]

    with torch.no_grad():
        traced_evoformer_stack = []
        for b in model.evoformer.blocks:
            traced_block = trace_block(b, evoformer_args)
            traced_evoformer_stack.append(traced_block)

    # Swap the original blocks for the traced wrappers
    del model.evoformer.blocks
    model.evoformer.blocks = traced_evoformer_stack

#    extra_msa_attn_chunk_size = max(
#        model.globals.chunk_size, extra_msa_chunk_size // 4
#    )
#    extra_msa_arg_tuples = [
#        ("m", a),
#        ("z", z),
#        ("msa_mask", extra_msa_mask),
#        ("pair_mask", pair_mask),
#        ("chunk_size", torch.tensor(extra_msa_chunk_size)),
#        ("use_lma", torch.tensor(model.globals.use_lma)),
#        ("inplace_safe", torch.tensor(1)),
#        ("_mask_trans", torch.tensor(model.config._mask_trans)),
#        ("_attn_chunk_size", torch.tensor(extra_msa_attn_chunk_size)),
#    ]
#    verify_arg_order(
#        model.extra_msa_stack.blocks[0].forward, extra_msa_arg_tuples
#    )
#    extra_msa_args = [arg for _, arg in extra_msa_arg_tuples]
#    with torch.no_grad():
#        traced_extra_msa_stack = []
#        for b in model.extra_msa_stack.blocks:
#            traced_block = trace_block(b, extra_msa_args)
#            traced_extra_msa_stack.append(traced_block)
#
#    del model.extra_msa_stack.blocks
#    model.extra_msa_stack.blocks = traced_extra_msa_stack

#    if(model.template_config.enabled):
#        template_pair_stack_attn_chunk_size = max(
#            model.globals.chunk_size, template_pair_stack_chunk_size // 4
#        )
#        template_pair_stack_arg_tuples = [
#            ("z", t),
#            ("mask", template_pair_mask),
#            ("chunk_size", torch.tensor(template_pair_stack_chunk_size)),
#            ("use_lma", torch.tensor(model.globals.use_lma)),
#            ("inplace_safe", torch.tensor(1)),
#            ("_mask_trans", torch.tensor(model.config._mask_trans)),
#            ("_attn_chunk_size", torch.tensor(
#                template_pair_stack_attn_chunk_size
#            )),
#        ]
#        verify_arg_order(
#            model.template_pair_stack.blocks[0].forward,
#            template_pair_stack_arg_tuples
#        )
#        template_pair_stack_args = [
#            arg for _, arg in template_pair_stack_arg_tuples
#        ]
#
#        with torch.no_grad():
#            traced_template_pair_stack = []
#            for b in model.template_pair_stack.blocks:
#                traced_block = trace_block(b, template_pair_stack_args)
#                traced_template_pair_stack.append(traced_block)
#
#        del model.template_pair_stack.blocks
#        model.template_pair_stack.blocks = traced_template_pair_stack

    # We need to do another dry run after tracing to allow the model to reach
    # top speeds. Why, I don't know.
    two_recycling_iter_input = tensor_tree_map(
        lambda t: t[..., :2], sample_input,
    )
    with torch.no_grad():
        _ = model(two_recycling_iter_input)
run_pretrained_openfold.py
View file @
4f53624d
...
@@ -12,14 +12,17 @@
...
@@ -12,14 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
from
copy
import
deepcopy
from
datetime
import
date
from
datetime
import
date
import
gc
import
logging
import
logging
import
math
import
numpy
as
np
import
numpy
as
np
import
os
import
os
from
copy
import
deepcopy
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
import
pickle
import
pickle
from
pytorch_lightning.utilities.deepspeed
import
(
from
pytorch_lightning.utilities.deepspeed
import
(
...
@@ -31,7 +34,19 @@ import time
...
@@ -31,7 +34,19 @@ import time
import
torch
import
torch
import
re
import
re
from
openfold.config
import
model_config
torch_versions
=
torch
.
__version__
.
split
(
"."
)
torch_major_version
=
int
(
torch_versions
[
0
])
torch_minor_version
=
int
(
torch_versions
[
1
])
if
(
torch_major_version
>
1
or
(
torch_major_version
==
1
and
torch_minor_version
>=
12
)
):
# Gives a large speedup on Ampere-class GPUs
torch
.
set_float32_matmul_precision
(
"high"
)
torch
.
set_grad_enabled
(
False
)
from
openfold.config
import
model_config
,
NUM_RES
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.model.torchscript
import
script_preset_
...
@@ -43,13 +58,14 @@ from openfold.utils.import_weights import (
...
@@ -43,13 +58,14 @@ from openfold.utils.import_weights import (
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
tensor_tree_map
,
)
)
from
openfold.utils.trace_utils
import
(
pad_feature_dict_seq
,
trace_model_
,
)
from
scripts.utils
import
add_data_args
from
scripts.utils
import
add_data_args
logging
.
basicConfig
()
TRACING_INTERVAL
=
50
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
):
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
):
...
@@ -59,9 +75,9 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -59,9 +75,9 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
(
args
.
use_precomputed_alignments
is
None
):
if
(
args
.
use_precomputed_alignments
is
None
and
not
os
.
path
.
isdir
(
local_alignment_dir
)
):
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
...
@@ -78,18 +94,21 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -78,18 +94,21 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
alignment_runner
.
run
(
alignment_runner
.
run
(
tmp_fasta_path
,
local_alignment_dir
tmp_fasta_path
,
local_alignment_dir
)
)
else
:
logger
.
info
(
f
"Using precomputed alignments for
{
tag
}
at
{
alignment_dir
}
..."
)
# Remove temporary FASTA file
# Remove temporary FASTA file
os
.
remove
(
tmp_fasta_path
)
os
.
remove
(
tmp_fasta_path
)
def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the nearest multiple of ``TRACING_INTERVAL``."""
    no_intervals = int(math.ceil(seqlen / TRACING_INTERVAL))
    return no_intervals * TRACING_INTERVAL
def
run_model
(
model
,
batch
,
tag
,
args
):
def
run_model
(
model
,
batch
,
tag
,
args
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
for
k
,
v
in
batch
.
items
()
}
# Disable templates if there aren't any in the batch
# Disable templates if there aren't any in the batch
model
.
config
.
template
.
enabled
=
model
.
config
.
template
.
enabled
and
any
([
model
.
config
.
template
.
enabled
=
model
.
config
.
template
.
enabled
and
any
([
"template_"
in
k
for
k
in
batch
"template_"
in
k
for
k
in
batch
...
@@ -208,6 +227,7 @@ def generate_feature_dict(
...
@@ -208,6 +227,7 @@ def generate_feature_dict(
return
feature_dict
return
feature_dict
def
get_model_basename
(
model_path
):
def
get_model_basename
(
model_path
):
return
os
.
path
.
splitext
(
return
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
basename
(
...
@@ -215,6 +235,7 @@ def get_model_basename(model_path):
...
@@ -215,6 +235,7 @@ def get_model_basename(model_path):
)
)
)[
0
]
)[
0
]
def
make_output_directory
(
output_dir
,
model_name
,
multiple_model_mode
):
def
make_output_directory
(
output_dir
,
model_name
,
multiple_model_mode
):
if
multiple_model_mode
:
if
multiple_model_mode
:
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
,
model_name
)
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
,
model_name
)
...
@@ -223,6 +244,7 @@ def make_output_directory(output_dir, model_name, multiple_model_mode):
...
@@ -223,6 +244,7 @@ def make_output_directory(output_dir, model_name, multiple_model_mode):
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
return
prediction_dir
return
prediction_dir
def
count_models_to_evaluate
(
openfold_checkpoint_path
,
jax_param_path
):
def
count_models_to_evaluate
(
openfold_checkpoint_path
,
jax_param_path
):
model_count
=
0
model_count
=
0
if
openfold_checkpoint_path
:
if
openfold_checkpoint_path
:
...
@@ -231,6 +253,7 @@ def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
...
@@ -231,6 +253,7 @@ def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
model_count
+=
len
(
jax_param_path
.
split
(
","
))
model_count
+=
len
(
jax_param_path
.
split
(
","
))
return
model_count
return
model_count
def
load_models_from_command_line
(
args
,
config
):
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
# Create the output directory
...
@@ -295,14 +318,23 @@ def load_models_from_command_line(args, config):
...
@@ -295,14 +318,23 @@ def load_models_from_command_line(args, config):
"be specified."
"be specified."
)
)
def
list_files_with_extensions
(
dir
,
extensions
):
def
list_files_with_extensions
(
dir
,
extensions
):
return
[
f
for
f
in
os
.
listdir
(
dir
)
if
f
.
endswith
(
extensions
)]
return
[
f
for
f
in
os
.
listdir
(
dir
)
if
f
.
endswith
(
extensions
)]
def
main
(
args
):
def
main
(
args
):
# Create the output directory
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
config_preset
)
config
=
model_config
(
args
.
config_preset
)
if
(
args
.
trace_model
):
if
(
not
config
.
data
.
predict
.
fixed_size
):
raise
ValueError
(
"Tracing requires that fixed_size mode be enabled in the config"
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
...
@@ -319,7 +351,11 @@ def main(args):
...
@@ -319,7 +351,11 @@ def main(args):
output_dir_base
=
args
.
output_dir
output_dir_base
=
args
.
output_dir
random_seed
=
args
.
data_random_seed
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
if
random_seed
is
None
:
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
random_seed
=
random
.
randrange
(
2
**
32
)
np
.
random
.
seed
(
random_seed
)
torch
.
manual_seed
(
random_seed
+
1
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
os
.
makedirs
(
output_dir_base
)
...
@@ -327,8 +363,9 @@ def main(args):
...
@@ -327,8 +363,9 @@ def main(args):
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
tag_list
=
[]
seq_list
=
[]
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
# Gather input sequences
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
...
@@ -338,12 +375,24 @@ def main(args):
...
@@ -338,12 +375,24 @@ def main(args):
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
tag
=
'-'
.
join
(
tags
)
tag_list
.
append
(
tag
)
seq_list
.
append
(
seqs
)
seq_sort_fn
=
lambda
target
:
sum
([
len
(
s
)
for
s
in
target
[
1
]])
sorted_targets
=
sorted
(
zip
(
tag_list
,
seq_list
),
key
=
seq_sort_fn
)
feature_dicts
=
{}
for
model
,
output_directory
in
load_models_from_command_line
(
args
,
config
):
cur_tracing_interval
=
0
for
tag
,
seqs
in
sorted_targets
:
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Does nothing if the alignments have already been computed
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
feature_dict
=
feature_dicts
.
get
(
tag
,
None
)
if
(
feature_dict
is
None
):
feature_dict
=
generate_feature_dict
(
feature_dict
=
generate_feature_dict
(
tags
,
tags
,
seqs
,
seqs
,
...
@@ -352,30 +401,57 @@ def main(args):
...
@@ -352,30 +401,57 @@ def main(args):
args
,
args
,
)
)
if
(
args
.
trace_model
):
n
=
feature_dict
[
"aatype"
].
shape
[
-
2
]
rounded_seqlen
=
round_up_seqlen
(
n
)
feature_dict
=
pad_feature_dict_seq
(
feature_dict
,
rounded_seqlen
,
)
feature_dicts
[
tag
]
=
feature_dict
processed_feature_dict
=
feature_processor
.
process_features
(
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
feature_dict
,
mode
=
'predict'
,
)
)
for
model
,
output_directory
in
load_models_from_command_line
(
args
,
config
):
processed_feature_dict
=
{
working_batch
=
deepcopy
(
processed_feature_dict
)
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
for
k
,
v
in
processed_feature_dict
.
items
()
}
if
(
args
.
trace_model
):
if
(
rounded_seqlen
>
cur_tracing_interval
):
logger
.
info
(
f
"Tracing model at
{
rounded_seqlen
}
residues..."
)
t
=
time
.
perf_counter
()
trace_model_
(
model
,
processed_feature_dict
)
logger
.
info
(
f
"Tracing time:
{
time
.
perf_counter
()
-
t
}
"
)
cur_tracing_interval
=
rounded_seqlen
out
=
run_model
(
model
,
processed_feature_dict
,
tag
,
args
)
# Toss out the recycling dimensions --- we don't need them anymore
# Toss out the recycling dimensions --- we don't need them anymore
working_batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
working_batch
)
processed_feature_dict
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
processed_feature_dict
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
unrelaxed_protein
=
prep_output
(
unrelaxed_protein
=
prep_output
(
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
out
,
processed_feature_dict
,
feature_dict
,
feature_processor
,
args
)
)
unrelaxed_output_path
=
os
.
path
.
join
(
unrelaxed_output_path
=
os
.
path
.
join
(
output_directory
,
f
'
{
output_name
}
_unrelaxed.pdb'
output_directory
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
)
# Output already exists
if
os
.
path
.
exists
(
unrelaxed_output_path
):
continue
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
...
@@ -481,6 +557,12 @@ if __name__ == "__main__":
...
@@ -481,6 +557,12 @@ if __name__ == "__main__":
"--multimer_ri_gap"
,
type
=
int
,
default
=
200
,
"--multimer_ri_gap"
,
type
=
int
,
default
=
200
,
help
=
"""Residue index offset between multiple sequences, if provided"""
help
=
"""Residue index offset between multiple sequences, if provided"""
)
)
parser
.
add_argument
(
"--trace_model"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""Whether to convert parts of each model to TorchScript.
Significantly improves runtime at the cost of lengthy
'compilation.' Useful for large batch jobs."""
)
add_data_args
(
parser
)
add_data_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment