Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b026de28
"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "63d6540ca692d62fe5447bfe7c63233126170774"
Commit
b026de28
authored
Oct 17, 2021
by
Gustaf Ahdritz
Browse files
Make chunk size more flexible, reduce verbosity
parent
a59ae7c1
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
111 additions
and
106 deletions
+111
-106
openfold/config.py
openfold/config.py
+6
-10
openfold/data/templates.py
openfold/data/templates.py
+2
-2
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+1
-1
openfold/model/evoformer.py
openfold/model/evoformer.py
+28
-24
openfold/model/model.py
openfold/model/model.py
+22
-13
openfold/model/msa.py
openfold/model/msa.py
+11
-16
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+4
-5
openfold/model/pair_transition.py
openfold/model/pair_transition.py
+4
-5
openfold/model/primitives.py
openfold/model/primitives.py
+7
-4
openfold/model/template.py
openfold/model/template.py
+16
-19
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+4
-5
train_openfold.py
train_openfold.py
+6
-2
No files found.
openfold/config.py
View file @
b026de28
...
...
@@ -45,13 +45,12 @@ def model_config(name, train=False, low_prec=False):
if
train
:
c
.
globals
.
blocks_per_ckpt
=
1
c
.
globals
.
chunk_size
=
None
if
low_prec
:
c
.
globals
.
eps
=
1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf
(
c
,
1e
4
)
set_inf
(
c
,
1e
5
)
return
c
...
...
@@ -225,7 +224,8 @@ config = mlc.ConfigDict(
# Recurring FieldReferences that can be changed globally here
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"train_chunk_size"
:
None
,
"eval_chunk_size"
:
chunk_size
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_t"
:
c_t
,
...
...
@@ -277,8 +277,7 @@ config = mlc.ConfigDict(
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"inf"
:
1e5
,
# 1e9,
"inf"
:
1e9
,
},
"template_pointwise_attention"
:
{
"c_t"
:
c_t
,
...
...
@@ -287,7 +286,6 @@ config = mlc.ConfigDict(
# It's actually 16.
"c_hidden"
:
16
,
"no_heads"
:
4
,
"chunk_size"
:
chunk_size
,
"inf"
:
1e5
,
# 1e9,
},
"inf"
:
1e5
,
# 1e9,
...
...
@@ -314,8 +312,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"inf"
:
1e5
,
# 1e9,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
},
"enabled"
:
True
,
...
...
@@ -335,8 +332,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"inf"
:
1e5
,
# 1e9,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
},
"structure_module"
:
{
...
...
openfold/data/templates.py
View file @
b026de28
...
...
@@ -165,7 +165,7 @@ def generate_release_dates_cache(mmcif_dir: str, out_path: str):
file_id
=
file_id
,
mmcif_string
=
mmcif_string
)
if
mmcif
.
mmcif_object
is
None
:
logging
.
warning
(
f
"Failed to parse
{
f
}
. Skipping..."
)
logging
.
info
(
f
"Failed to parse
{
f
}
. Skipping..."
)
continue
mmcif
=
mmcif
.
mmcif_object
...
...
@@ -822,7 +822,7 @@ def _process_single_hit(
if
strict_error_check
:
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
else
:
logging
.
warning
(
error
)
logging
.
info
(
error
)
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
None
)
try
:
...
...
openfold/data/tools/hhsearch.py
View file @
b026de28
...
...
@@ -20,7 +20,7 @@ import os
import
subprocess
from
typing
import
Sequence
from
openfold.data.
np
import
utils
from
openfold.data.
tools
import
utils
class
HHSearch
:
...
...
openfold/model/evoformer.py
View file @
b026de28
...
...
@@ -46,7 +46,7 @@ class MSATransition(nn.Module):
Implements Algorithm 9
"""
def
__init__
(
self
,
c_m
,
n
,
chunk_size
):
def
__init__
(
self
,
c_m
,
n
):
"""
Args:
c_m:
...
...
@@ -59,7 +59,6 @@ class MSATransition(nn.Module):
self
.
c_m
=
c_m
self
.
n
=
n
self
.
chunk_size
=
chunk_size
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
c_m
)
self
.
linear_1
=
Linear
(
self
.
c_m
,
self
.
n
*
self
.
c_m
,
init
=
"relu"
)
...
...
@@ -76,6 +75,7 @@ class MSATransition(nn.Module):
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
=
None
,
chunk_size
:
int
=
None
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -96,11 +96,11 @@ class MSATransition(nn.Module):
m
=
self
.
layer_norm
(
m
)
inp
=
{
"m"
:
m
,
"mask"
:
mask
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
chunk_layer
(
self
.
_transition
,
inp
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
]),
)
else
:
...
...
@@ -123,7 +123,6 @@ class EvoformerBlock(nn.Module):
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
chunk_size
:
int
,
inf
:
float
,
eps
:
float
,
_is_extra_msa_stack
:
bool
=
False
,
...
...
@@ -135,7 +134,6 @@ class EvoformerBlock(nn.Module):
c_z
=
c_z
,
c_hidden
=
c_hidden_msa_att
,
no_heads
=
no_heads_msa
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
...
...
@@ -144,7 +142,6 @@ class EvoformerBlock(nn.Module):
c_in
=
c_m
,
c_hidden
=
c_hidden_msa_att
,
no_heads
=
no_heads_msa
,
chunk_size
=
chunk_size
,
inf
=
inf
,
eps
=
eps
,
)
...
...
@@ -153,21 +150,18 @@ class EvoformerBlock(nn.Module):
c_m
,
c_hidden_msa_att
,
no_heads_msa
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
self
.
msa_transition
=
MSATransition
(
c_m
=
c_m
,
n
=
transition_n
,
chunk_size
=
chunk_size
,
)
self
.
outer_product_mean
=
OuterProductMean
(
c_m
,
c_z
,
c_hidden_opm
,
chunk_size
=
chunk_size
,
)
self
.
tri_mul_out
=
TriangleMultiplicationOutgoing
(
...
...
@@ -183,21 +177,18 @@ class EvoformerBlock(nn.Module):
c_z
,
c_hidden_pair_att
,
no_heads_pair
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
self
.
tri_att_end
=
TriangleAttentionEndingNode
(
c_z
,
c_hidden_pair_att
,
no_heads_pair
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
self
.
pair_transition
=
PairTransition
(
c_z
,
transition_n
,
chunk_size
=
chunk_size
,
)
self
.
msa_dropout_layer
=
DropoutRowwise
(
msa_dropout
)
...
...
@@ -210,6 +201,7 @@ class EvoformerBlock(nn.Module):
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
...
...
@@ -218,15 +210,27 @@ class EvoformerBlock(nn.Module):
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
m
=
m
+
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
m
,
z
,
mask
=
msa_mask
))
m
=
m
+
self
.
msa_att_col
(
m
,
mask
=
msa_mask
)
m
=
m
+
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
)
z
=
z
+
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
)
m
=
m
+
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
m
,
z
=
z
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
)
m
=
m
+
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
m
=
m
+
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
)
z
=
z
+
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_att_start
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_col_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
)
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_att_start
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
)
)
z
=
z
+
self
.
ps_dropout_col_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
)
)
z
=
z
+
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
)
return
m
,
z
...
...
@@ -254,7 +258,6 @@ class EvoformerStack(nn.Module):
msa_dropout
:
float
,
pair_dropout
:
float
,
blocks_per_ckpt
:
int
,
chunk_size
:
int
,
inf
:
float
,
eps
:
float
,
_is_extra_msa_stack
:
bool
=
False
,
...
...
@@ -312,7 +315,6 @@ class EvoformerStack(nn.Module):
transition_n
=
transition_n
,
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
chunk_size
=
chunk_size
,
inf
=
inf
,
eps
=
eps
,
_is_extra_msa_stack
=
_is_extra_msa_stack
,
...
...
@@ -328,6 +330,7 @@ class EvoformerStack(nn.Module):
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
...
...
@@ -354,6 +357,7 @@ class EvoformerStack(nn.Module):
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
...
...
@@ -392,7 +396,6 @@ class ExtraMSAStack(nn.Module):
msa_dropout
:
float
,
pair_dropout
:
float
,
blocks_per_ckpt
:
int
,
chunk_size
:
int
,
inf
:
float
,
eps
:
float
,
**
kwargs
,
...
...
@@ -415,7 +418,6 @@ class ExtraMSAStack(nn.Module):
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
blocks_per_ckpt
=
blocks_per_ckpt
,
chunk_size
=
chunk_size
,
inf
=
inf
,
eps
=
eps
,
_is_extra_msa_stack
=
True
,
...
...
@@ -425,6 +427,7 @@ class ExtraMSAStack(nn.Module):
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
...
...
@@ -447,6 +450,7 @@ class ExtraMSAStack(nn.Module):
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
return
z
openfold/model/model.py
View file @
b026de28
...
...
@@ -63,6 +63,8 @@ class AlphaFold(nn.Module):
"""
super
(
AlphaFold
,
self
).
__init__
()
self
.
globals
=
config
.
globals
config
=
config
.
model
template_config
=
config
.
template
extra_msa_config
=
config
.
extra_msa
...
...
@@ -104,7 +106,7 @@ class AlphaFold(nn.Module):
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
chunk_size
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds
=
[]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
@@ -135,14 +137,13 @@ class AlphaFold(nn.Module):
)
t
=
self
.
template_pair_embedder
(
t
)
t
=
self
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
),
_mask_trans
=
self
.
config
.
_mask_trans
t
,
pair_mask
.
unsqueeze
(
-
3
),
chunk_size
=
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
single_template_embeds
.
update
(
{
"pair"
:
t
,
}
)
single_template_embeds
.
update
({
"pair"
:
t
})
template_embeds
.
append
(
single_template_embeds
)
...
...
@@ -153,7 +154,10 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
template_embeds
[
"pair"
],
z
,
template_mask
=
batch
[
"template_mask"
]
template_embeds
[
"pair"
],
z
,
template_mask
=
batch
[
"template_mask"
],
chunk_size
=
chunk_size
,
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
...
...
@@ -161,15 +165,17 @@ class AlphaFold(nn.Module):
if
self
.
config
.
template
.
embed_angles
:
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
ret
.
update
(
{
"template_pair_embedding"
:
t
,
}
)
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
):
# Establish constants
chunk_size
=
(
self
.
globals
.
train_chunk_size
if
self
.
training
else
self
.
globals
.
eval_chunk_size
)
# Primary output dictionary
outputs
=
{}
...
...
@@ -243,6 +249,7 @@ class AlphaFold(nn.Module):
z
,
pair_mask
,
no_batch_dims
,
chunk_size
,
)
# [*, N, N, C_z]
...
...
@@ -270,6 +277,7 @@ class AlphaFold(nn.Module):
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
],
chunk_size
=
chunk_size
,
pair_mask
=
pair_mask
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
@@ -283,6 +291,7 @@ class AlphaFold(nn.Module):
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
...
...
openfold/model/msa.py
View file @
b026de28
...
...
@@ -34,7 +34,6 @@ class MSAAttention(nn.Module):
no_heads
,
pair_bias
=
False
,
c_z
=
None
,
chunk_size
=
4
,
inf
=
1e9
,
):
"""
...
...
@@ -60,7 +59,6 @@ class MSAAttention(nn.Module):
self
.
no_heads
=
no_heads
self
.
pair_bias
=
pair_bias
self
.
c_z
=
c_z
self
.
chunk_size
=
chunk_size
self
.
inf
=
inf
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_in
)
...
...
@@ -75,7 +73,7 @@ class MSAAttention(nn.Module):
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
)
def
forward
(
self
,
m
,
z
=
None
,
mask
=
None
):
def
forward
(
self
,
m
,
chunk_size
,
z
=
None
,
mask
=
None
):
"""
Args:
m:
...
...
@@ -117,11 +115,11 @@ class MSAAttention(nn.Module):
biases
.
append
(
z
)
mha_inputs
=
{
"q_x"
:
m
,
"k_x"
:
m
,
"v_x"
:
m
,
"biases"
:
biases
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
chunk_layer
(
self
.
mha
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
]),
)
else
:
...
...
@@ -135,7 +133,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
Implements Algorithm 7.
"""
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
=
1e9
):
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
no_heads
,
inf
=
1e9
):
"""
Args:
c_m:
...
...
@@ -155,7 +153,6 @@ class MSARowAttentionWithPairBias(MSAAttention):
no_heads
,
pair_bias
=
True
,
c_z
=
c_z
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
...
...
@@ -165,7 +162,7 @@ class MSAColumnAttention(MSAAttention):
Implements Algorithm 8.
"""
def
__init__
(
self
,
c_m
,
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
):
def
__init__
(
self
,
c_m
,
c_hidden
,
no_heads
,
inf
=
1e9
):
"""
Args:
c_m:
...
...
@@ -183,11 +180,10 @@ class MSAColumnAttention(MSAAttention):
no_heads
=
no_heads
,
pair_bias
=
False
,
c_z
=
None
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
def
forward
(
self
,
m
,
mask
=
None
):
def
forward
(
self
,
m
,
chunk_size
,
mask
=
None
):
"""
Args:
m:
...
...
@@ -200,7 +196,7 @@ class MSAColumnAttention(MSAAttention):
if
mask
is
not
None
:
mask
=
mask
.
transpose
(
-
1
,
-
2
)
m
=
super
().
forward
(
m
,
mask
=
mask
)
m
=
super
().
forward
(
m
,
chunk_size
=
chunk_size
,
mask
=
mask
)
# [*, N_seq, N_res, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
...
...
@@ -211,14 +207,13 @@ class MSAColumnAttention(MSAAttention):
class
MSAColumnGlobalAttention
(
nn
.
Module
):
def
__init__
(
self
,
c_in
,
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
,
eps
=
1e-10
self
,
c_in
,
c_hidden
,
no_heads
,
inf
=
1e9
,
eps
=
1e-10
):
super
(
MSAColumnGlobalAttention
,
self
).
__init__
()
self
.
c_in
=
c_in
self
.
c_hidden
=
c_hidden
self
.
no_heads
=
no_heads
self
.
chunk_size
=
chunk_size
self
.
inf
=
inf
self
.
eps
=
eps
...
...
@@ -233,7 +228,7 @@ class MSAColumnGlobalAttention(nn.Module):
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
self
,
m
:
torch
.
Tensor
,
chunk_size
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
...
...
@@ -256,11 +251,11 @@ class MSAColumnGlobalAttention(nn.Module):
"m"
:
m
,
"mask"
:
mask
,
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
chunk_layer
(
self
.
global_attention
,
mha_input
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
]),
)
else
:
...
...
openfold/model/outer_product_mean.py
View file @
b026de28
...
...
@@ -26,7 +26,7 @@ class OuterProductMean(nn.Module):
Implements Algorithm 10.
"""
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
chunk_size
=
4
,
eps
=
1e-3
):
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
eps
=
1e-3
):
"""
Args:
c_m:
...
...
@@ -40,7 +40,6 @@ class OuterProductMean(nn.Module):
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
chunk_size
=
chunk_size
self
.
eps
=
eps
self
.
layer_norm
=
nn
.
LayerNorm
(
c_m
)
...
...
@@ -60,7 +59,7 @@ class OuterProductMean(nn.Module):
return
outer
def
forward
(
self
,
m
,
mask
=
None
):
def
forward
(
self
,
m
,
chunk_size
,
mask
=
None
):
"""
Args:
m:
...
...
@@ -84,7 +83,7 @@ class OuterProductMean(nn.Module):
a
=
a
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
...
...
@@ -95,7 +94,7 @@ class OuterProductMean(nn.Module):
outer
=
chunk_layer
(
partial
(
self
.
_opm
,
b
=
b_prime
),
{
"a"
:
a_prime
},
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
1
,
)
out
.
append
(
outer
)
...
...
openfold/model/pair_transition.py
View file @
b026de28
...
...
@@ -25,7 +25,7 @@ class PairTransition(nn.Module):
Implements Algorithm 15.
"""
def
__init__
(
self
,
c_z
,
n
,
chunk_size
=
4
):
def
__init__
(
self
,
c_z
,
n
):
"""
Args:
c_z:
...
...
@@ -38,7 +38,6 @@ class PairTransition(nn.Module):
self
.
c_z
=
c_z
self
.
n
=
n
self
.
chunk_size
=
chunk_size
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
c_z
)
self
.
linear_1
=
Linear
(
self
.
c_z
,
self
.
n
*
self
.
c_z
,
init
=
"relu"
)
...
...
@@ -55,7 +54,7 @@ class PairTransition(nn.Module):
return
z
def
forward
(
self
,
z
,
mask
=
None
):
def
forward
(
self
,
z
,
chunk_size
,
mask
=
None
):
"""
Args:
z:
...
...
@@ -74,11 +73,11 @@ class PairTransition(nn.Module):
z
=
self
.
layer_norm
(
z
)
inp
=
{
"z"
:
z
,
"mask"
:
mask
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
z
=
chunk_layer
(
self
.
_transition
,
inp
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
)
else
:
...
...
openfold/model/primitives.py
View file @
b026de28
...
...
@@ -260,11 +260,14 @@ class Attention(nn.Module):
# [*, H, Q, K]
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
0
,
2
,
1
,
3
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
k
,
(
0
,
2
,
3
,
1
)),
# [*, H, C_hidden, K]
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, K]
)
del
q
,
k
norm
=
1
/
math
.
sqrt
(
self
.
c_hidden
)
# [1]
a
=
a
*
norm
a
*
=
norm
if
biases
is
not
None
:
for
b
in
biases
:
a
=
a
+
b
...
...
@@ -273,7 +276,7 @@ class Attention(nn.Module):
# [*, H, Q, C_hidden]
o
=
torch
.
matmul
(
a
,
permute_final_dims
(
v
,
(
0
,
2
,
1
,
3
)),
# [*, H, V, C_hidden]
permute_final_dims
(
v
,
(
1
,
0
,
2
)),
# [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden]
...
...
openfold/model/template.py
View file @
b026de28
...
...
@@ -45,7 +45,7 @@ class TemplatePointwiseAttention(nn.Module):
Implements Algorithm 17.
"""
def
__init__
(
self
,
c_t
,
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
,
**
kwargs
):
def
__init__
(
self
,
c_t
,
c_z
,
c_hidden
,
no_heads
,
inf
,
**
kwargs
):
"""
Args:
c_t:
...
...
@@ -61,7 +61,6 @@ class TemplatePointwiseAttention(nn.Module):
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
no_heads
=
no_heads
self
.
chunk_size
=
chunk_size
self
.
inf
=
inf
self
.
mha
=
Attention
(
...
...
@@ -73,7 +72,7 @@ class TemplatePointwiseAttention(nn.Module):
gating
=
False
,
)
def
forward
(
self
,
t
,
z
,
template_mask
=
None
):
def
forward
(
self
,
t
,
z
,
chunk_size
,
template_mask
=
None
):
"""
Args:
t:
...
...
@@ -106,11 +105,11 @@ class TemplatePointwiseAttention(nn.Module):
"v_x"
:
t
,
"biases"
:
[
bias
],
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
z
=
chunk_layer
(
self
.
mha
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
)
else
:
...
...
@@ -131,7 +130,6 @@ class TemplatePairStackBlock(nn.Module):
no_heads
,
pair_transition_n
,
dropout_rate
,
chunk_size
,
inf
,
**
kwargs
,
):
...
...
@@ -143,7 +141,6 @@ class TemplatePairStackBlock(nn.Module):
self
.
no_heads
=
no_heads
self
.
pair_transition_n
=
pair_transition_n
self
.
dropout_rate
=
dropout_rate
self
.
chunk_size
=
chunk_size
self
.
inf
=
inf
self
.
dropout_row
=
DropoutRowwise
(
self
.
dropout_rate
)
...
...
@@ -153,14 +150,12 @@ class TemplatePairStackBlock(nn.Module):
self
.
c_t
,
self
.
c_hidden_tri_att
,
self
.
no_heads
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
self
.
tri_att_end
=
TriangleAttentionEndingNode
(
self
.
c_t
,
self
.
c_hidden_tri_att
,
self
.
no_heads
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
...
...
@@ -176,15 +171,20 @@ class TemplatePairStackBlock(nn.Module):
self
.
pair_transition
=
PairTransition
(
self
.
c_t
,
self
.
pair_transition_n
,
chunk_size
=
chunk_size
,
)
def
forward
(
self
,
z
,
mask
,
_mask_trans
=
True
):
z
=
z
+
self
.
dropout_row
(
self
.
tri_att_start
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_col
(
self
.
tri_att_end
(
z
,
mask
=
mask
))
def
forward
(
self
,
z
,
mask
,
chunk_size
,
_mask_trans
=
True
):
z
=
z
+
self
.
dropout_row
(
self
.
tri_att_start
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
)
z
=
z
+
self
.
dropout_col
(
self
.
tri_att_end
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
)
)
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_out
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_in
(
z
,
mask
=
mask
))
z
=
z
+
self
.
pair_transition
(
z
,
mask
=
mask
if
_mask_trans
else
None
)
z
=
z
+
self
.
pair_transition
(
z
,
chunk_size
=
chunk_size
,
mask
=
mask
if
_mask_trans
else
None
)
return
z
...
...
@@ -204,7 +204,6 @@ class TemplatePairStack(nn.Module):
pair_transition_n
,
dropout_rate
,
blocks_per_ckpt
,
chunk_size
,
inf
=
1e9
,
**
kwargs
,
):
...
...
@@ -225,9 +224,6 @@ class TemplatePairStack(nn.Module):
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
chunk_size:
Size of subbatches. A higher value increases throughput at
the cost of memory
"""
super
(
TemplatePairStack
,
self
).
__init__
()
...
...
@@ -242,7 +238,6 @@ class TemplatePairStack(nn.Module):
no_heads
=
no_heads
,
pair_transition_n
=
pair_transition_n
,
dropout_rate
=
dropout_rate
,
chunk_size
=
chunk_size
,
inf
=
inf
,
)
self
.
blocks
.
append
(
block
)
...
...
@@ -253,6 +248,7 @@ class TemplatePairStack(nn.Module):
self
,
t
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
chunk_size
:
int
,
_mask_trans
:
bool
=
True
,
):
"""
...
...
@@ -269,6 +265,7 @@ class TemplatePairStack(nn.Module):
partial
(
b
,
mask
=
mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
...
...
openfold/model/triangular_attention.py
View file @
b026de28
...
...
@@ -28,7 +28,7 @@ from openfold.utils.tensor_utils import (
class
TriangleAttention
(
nn
.
Module
):
def
__init__
(
self
,
c_in
,
c_hidden
,
no_heads
,
starting
,
chunk_size
=
4
,
inf
=
1e9
self
,
c_in
,
c_hidden
,
no_heads
,
starting
,
inf
=
1e9
):
"""
Args:
...
...
@@ -45,7 +45,6 @@ class TriangleAttention(nn.Module):
self
.
c_hidden
=
c_hidden
self
.
no_heads
=
no_heads
self
.
starting
=
starting
self
.
chunk_size
=
chunk_size
self
.
inf
=
inf
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
c_in
)
...
...
@@ -56,7 +55,7 @@ class TriangleAttention(nn.Module):
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
)
def
forward
(
self
,
x
,
mask
=
None
):
def
forward
(
self
,
x
,
chunk_size
,
mask
=
None
):
"""
Args:
x:
...
...
@@ -93,11 +92,11 @@ class TriangleAttention(nn.Module):
"v_x"
:
x
,
"biases"
:
[
mask_bias
,
triangle_bias
],
}
if
self
.
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
x
=
chunk_layer
(
self
.
mha
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
x
.
shape
[:
-
2
]),
)
else
:
...
...
train_openfold.py
View file @
b026de28
...
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"
4
"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"
5
"
import
random
import
time
...
...
@@ -26,7 +26,7 @@ class OpenFoldWrapper(pl.LightningModule):
def
__init__
(
self
,
config
):
super
(
OpenFoldWrapper
,
self
).
__init__
()
self
.
config
=
config
self
.
model
=
AlphaFold
(
config
.
model
)
self
.
model
=
AlphaFold
(
config
)
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
self
.
model
,
decay
=
config
.
ema
.
decay
)
...
...
@@ -50,6 +50,9 @@ class OpenFoldWrapper(pl.LightningModule):
with
open
(
"prediction/preds_"
+
str
(
time
.
strftime
(
"%H:%M:%S"
))
+
".pickle"
,
"wb"
)
as
f
:
pickle
.
dump
(
out
,
f
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
#def validation_step(self, batch, batch_idx):
# outputs = self(batch)
def
configure_optimizers
(
self
,
learning_rate
:
float
=
1e-3
,
eps
:
float
=
1e-8
...
...
@@ -64,6 +67,7 @@ class OpenFoldWrapper(pl.LightningModule):
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
self
.
ema
.
update
(
self
.
model
)
def
main
(
args
):
config
=
model_config
(
"model_1"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment