OpenDAS / OpenFold · Commits · b026de28

Commit b026de28, authored Oct 17, 2021 by Gustaf Ahdritz
Parent: a59ae7c1

    Make chunk size more flexible, reduce verbosity

Showing 12 changed files with 111 additions and 106 deletions (+111, -106)
Changed files:
    openfold/config.py                      +6   -10
    openfold/data/templates.py              +2   -2
    openfold/data/tools/hhsearch.py         +1   -1
    openfold/model/evoformer.py             +28  -24
    openfold/model/model.py                 +22  -13
    openfold/model/msa.py                   +11  -16
    openfold/model/outer_product_mean.py    +4   -5
    openfold/model/pair_transition.py       +4   -5
    openfold/model/primitives.py            +7   -4
    openfold/model/template.py              +16  -19
    openfold/model/triangular_attention.py  +4   -5
    train_openfold.py                       +6   -2
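The recurring pattern in this commit: chunk_size is no longer fixed when a module is constructed but is passed as a forward() argument, so the same weights can run with different sub-batch sizes during training and inference. A minimal before/after sketch of that change (illustrative toy modules, not the actual OpenFold classes):

    import torch
    import torch.nn as nn

    # Before: chunk size baked into the module at construction time.
    class TransitionOld(nn.Module):
        def __init__(self, c, n, chunk_size):
            super().__init__()
            self.chunk_size = chunk_size          # fixed for the lifetime of the module
            self.linear = nn.Linear(c, n * c)

        def forward(self, x):
            # self.chunk_size would drive the sub-batching here
            return self.linear(x)

    # After: chunk size supplied on every call, so it can differ between train and eval.
    class TransitionNew(nn.Module):
        def __init__(self, c, n):
            super().__init__()
            self.linear = nn.Linear(c, n * c)

        def forward(self, x, chunk_size=None):
            # chunk_size is threaded down from the top-level model per call
            return self.linear(x)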
openfold/config.py

@@ -45,13 +45,12 @@ def model_config(name, train=False, low_prec=False):
     if train:
         c.globals.blocks_per_ckpt = 1
-        c.globals.chunk_size = None
     if low_prec:
         c.globals.eps = 1e-4
         # If we want exact numerical parity with the original, inf can't be
         # a global constant
-        set_inf(c, 1e4)
+        set_inf(c, 1e5)
     return c

@@ -225,7 +224,8 @@ config = mlc.ConfigDict(
         # Recurring FieldReferences that can be changed globally here
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
-            "chunk_size": chunk_size,
+            "train_chunk_size": None,
+            "eval_chunk_size": chunk_size,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,

@@ -277,8 +277,7 @@ config = mlc.ConfigDict(
             "pair_transition_n": 2,
             "dropout_rate": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
-            "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
         },
         "template_pointwise_attention": {
             "c_t": c_t,

@@ -287,7 +286,6 @@ config = mlc.ConfigDict(
             # It's actually 16.
             "c_hidden": 16,
             "no_heads": 4,
-            "chunk_size": chunk_size,
             "inf": 1e5,  # 1e9,
         },
         "inf": 1e5,  # 1e9,

@@ -314,8 +312,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
-            "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
             "eps": eps,  # 1e-10,
         },
         "enabled": True,

@@ -335,8 +332,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
-            "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
             "eps": eps,  # 1e-10,
         },
         "structure_module": {
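With this change the globals section exposes two chunk-size knobs instead of one. A hedged sketch of how a downstream script might override them (field names taken from the diff above; model_config is the existing OpenFold entry point):

    from openfold.config import model_config

    config = model_config("model_1", train=False)
    # Chunking is now configured separately for training and evaluation.
    config.globals.train_chunk_size = None   # no sub-batching during training
    config.globals.eval_chunk_size = 4       # sub-batch size used at inference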
openfold/data/templates.py

@@ -165,7 +165,7 @@ def generate_release_dates_cache(mmcif_dir: str, out_path: str):
                 file_id=file_id, mmcif_string=mmcif_string
             )
             if mmcif.mmcif_object is None:
-                logging.warning(f"Failed to parse {f}. Skipping...")
+                logging.info(f"Failed to parse {f}. Skipping...")
                 continue
             mmcif = mmcif.mmcif_object

@@ -822,7 +822,7 @@ def _process_single_hit(
         if strict_error_check:
             return SingleHitResult(features=None, error=error, warning=None)
         else:
-            logging.warning(error)
+            logging.info(error)
             return SingleHitResult(features=None, error=None, warning=None)

     try:
openfold/data/tools/hhsearch.py

@@ -20,7 +20,7 @@ import os
 import subprocess
 from typing import Sequence

-from openfold.data.np import utils
+from openfold.data.tools import utils


 class HHSearch:
openfold/model/evoformer.py

@@ -46,7 +46,7 @@ class MSATransition(nn.Module):
     Implements Algorithm 9
     """

-    def __init__(self, c_m, n, chunk_size):
+    def __init__(self, c_m, n):
         """
         Args:
             c_m:

@@ -59,7 +59,6 @@ class MSATransition(nn.Module):
         self.c_m = c_m
         self.n = n
-        self.chunk_size = chunk_size

         self.layer_norm = nn.LayerNorm(self.c_m)
         self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")

@@ -76,6 +75,7 @@ class MSATransition(nn.Module):
         self,
         m: torch.Tensor,
         mask: torch.Tensor = None,
+        chunk_size: int = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -96,11 +96,11 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)

         inp = {"m": m, "mask": mask}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self._transition,
                 inp,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:

@@ -123,7 +123,6 @@ class EvoformerBlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        chunk_size: int,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,

@@ -135,7 +134,6 @@ class EvoformerBlock(nn.Module):
             c_z=c_z,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
         )

@@ -144,7 +142,6 @@ class EvoformerBlock(nn.Module):
             c_in=c_m,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
             eps=eps,
         )

@@ -153,21 +150,18 @@ class EvoformerBlock(nn.Module):
             c_m,
             c_hidden_msa_att,
             no_heads_msa,
-            chunk_size=chunk_size,
             inf=inf,
         )

         self.msa_transition = MSATransition(
             c_m=c_m,
             n=transition_n,
-            chunk_size=chunk_size,
         )

         self.outer_product_mean = OuterProductMean(
             c_m,
             c_z,
             c_hidden_opm,
-            chunk_size=chunk_size,
         )

         self.tri_mul_out = TriangleMultiplicationOutgoing(

@@ -183,21 +177,18 @@ class EvoformerBlock(nn.Module):
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
-            chunk_size=chunk_size,
             inf=inf,
         )

         self.pair_transition = PairTransition(
             c_z,
             transition_n,
-            chunk_size=chunk_size,
         )

         self.msa_dropout_layer = DropoutRowwise(msa_dropout)

@@ -210,6 +201,7 @@ class EvoformerBlock(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans

@@ -218,15 +210,27 @@ class EvoformerBlock(nn.Module):
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_dropout_layer(self.msa_att_row(m, z, mask=msa_mask))
-        m = m + self.msa_att_col(m, mask=msa_mask)
-        m = m + self.msa_transition(m, mask=msa_trans_mask)
-        z = z + self.outer_product_mean(m, mask=msa_mask)
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+        )
+        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_transition(m, mask=msa_trans_mask, chunk_size=chunk_size)
+        z = z + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask))
-        z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask))
-        z = z + self.pair_transition(z, mask=pair_trans_mask)
+        z = z + self.ps_dropout_row_layer(
+            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+        )
+        z = z + self.ps_dropout_col_layer(
+            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+        )
+        z = z + self.pair_transition(z, mask=pair_trans_mask, chunk_size=chunk_size)

         return m, z

@@ -254,7 +258,6 @@ class EvoformerStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
-        chunk_size: int,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,

@@ -312,7 +315,6 @@ class EvoformerStack(nn.Module):
                 transition_n=transition_n,
                 msa_dropout=msa_dropout,
                 pair_dropout=pair_dropout,
-                chunk_size=chunk_size,
                 inf=inf,
                 eps=eps,
                 _is_extra_msa_stack=_is_extra_msa_stack,

@@ -328,6 +330,7 @@ class EvoformerStack(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """

@@ -354,6 +357,7 @@ class EvoformerStack(nn.Module):
                 b,
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
+                chunk_size=chunk_size,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks

@@ -392,7 +396,6 @@ class ExtraMSAStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
-        chunk_size: int,
         inf: float,
         eps: float,
         **kwargs,

@@ -415,7 +418,6 @@ class ExtraMSAStack(nn.Module):
             msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
             blocks_per_ckpt=blocks_per_ckpt,
-            chunk_size=chunk_size,
             inf=inf,
             eps=eps,
             _is_extra_msa_stack=True,

@@ -425,6 +427,7 @@ class ExtraMSAStack(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
+        chunk_size: int,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,

@@ -447,6 +450,7 @@ class ExtraMSAStack(nn.Module):
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            chunk_size=chunk_size,
             _mask_trans=_mask_trans,
         )
         return z
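All of these call sites delegate sub-batching to chunk_layer, which splits the flattened batch dimensions of its inputs into slices of chunk_size, applies the wrapped layer to each slice, and reassembles the outputs. A simplified, hedged sketch of that idea (not OpenFold's actual openfold.utils.tensor_utils implementation, which also handles nested inputs such as the bias lists above):

    import torch

    def chunk_layer_sketch(layer, inputs: dict, chunk_size: int, no_batch_dims: int = 1):
        # Flatten the leading batch dims, run the layer on slices of size
        # chunk_size, then restore the original batch shape.
        flat = {k: v.reshape(-1, *v.shape[no_batch_dims:]) for k, v in inputs.items()}
        n = next(iter(flat.values())).shape[0]
        outs = []
        for start in range(0, n, chunk_size):
            sliced = {k: v[start:start + chunk_size] for k, v in flat.items()}
            outs.append(layer(**sliced))
        out = torch.cat(outs, dim=0)
        batch_shape = next(iter(inputs.values())).shape[:no_batch_dims]
        return out.reshape(*batch_shape, *out.shape[1:])

Smaller chunk_size values bound peak activation memory at the cost of more kernel launches; chunk_size=None skips the mechanism entirely and runs the layer on the full batch.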
openfold/model/model.py

@@ -63,6 +63,8 @@ class AlphaFold(nn.Module):
         """
         super(AlphaFold, self).__init__()

+        self.globals = config.globals
+        config = config.model
         template_config = config.template
         extra_msa_config = config.extra_msa

@@ -104,7 +106,7 @@ class AlphaFold(nn.Module):
         self.config = config

-    def embed_templates(self, batch, z, pair_mask, templ_dim):
+    def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]

@@ -135,14 +137,13 @@ class AlphaFold(nn.Module):
             )
             t = self.template_pair_embedder(t)
             t = self.template_pair_stack(
-                t, pair_mask.unsqueeze(-3), _mask_trans=self.config._mask_trans
+                t,
+                pair_mask.unsqueeze(-3),
+                chunk_size=chunk_size,
+                _mask_trans=self.config._mask_trans,
             )

-            single_template_embeds.update(
-                {
-                    "pair": t,
-                }
-            )
+            single_template_embeds.update({"pair": t})
             template_embeds.append(single_template_embeds)

@@ -153,7 +154,10 @@ class AlphaFold(nn.Module):
         # [*, N, N, C_z]
         t = self.template_pointwise_att(
-            template_embeds["pair"], z, template_mask=batch["template_mask"]
+            template_embeds["pair"],
+            z,
+            template_mask=batch["template_mask"],
+            chunk_size=chunk_size,
         )
         t = t * (torch.sum(batch["template_mask"]) > 0)

@@ -161,15 +165,17 @@ class AlphaFold(nn.Module):
         if self.config.template.embed_angles:
             ret["template_angle_embedding"] = template_embeds["angle"]

-        ret.update(
-            {
-                "template_pair_embedding": t,
-            }
-        )
+        ret.update({"template_pair_embedding": t})

         return ret

     def iteration(self, feats, m_1_prev, z_prev, x_prev):
+        # Establish constants
+        chunk_size = (
+            self.globals.train_chunk_size
+            if self.training
+            else self.globals.eval_chunk_size
+        )
+
         # Primary output dictionary
         outputs = {}

@@ -243,6 +249,7 @@ class AlphaFold(nn.Module):
             z,
             pair_mask,
             no_batch_dims,
+            chunk_size,
         )

         # [*, N, N, C_z]

@@ -270,6 +277,7 @@ class AlphaFold(nn.Module):
             a,
             z,
             msa_mask=feats["extra_msa_mask"],
+            chunk_size=chunk_size,
             pair_mask=pair_mask,
             _mask_trans=self.config._mask_trans,
         )

@@ -283,6 +291,7 @@ class AlphaFold(nn.Module):
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            chunk_size=chunk_size,
             _mask_trans=self.config._mask_trans,
         )
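The new "Establish constants" block is where the two config fields meet the per-call chunk_size arguments: in training mode iteration() reads globals.train_chunk_size, in eval mode globals.eval_chunk_size, and threads the chosen value into every chunk-aware forward call below it. A hedged usage sketch (module paths as in this repository):

    from openfold.config import model_config
    from openfold.model.model import AlphaFold

    config = model_config("model_1")
    config.globals.eval_chunk_size = 4   # small sub-batches to bound peak memory

    model = AlphaFold(config)            # note: the full config, not config.model
    model.train()                        # iteration() will read globals.train_chunk_size
    model.eval()                         # ...and globals.eval_chunk_size in eval mode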
openfold/model/msa.py

@@ -34,7 +34,6 @@ class MSAAttention(nn.Module):
         no_heads,
         pair_bias=False,
         c_z=None,
-        chunk_size=4,
         inf=1e9,
     ):
         """

@@ -60,7 +59,6 @@ class MSAAttention(nn.Module):
         self.no_heads = no_heads
         self.pair_bias = pair_bias
         self.c_z = c_z
-        self.chunk_size = chunk_size
         self.inf = inf

         self.layer_norm_m = nn.LayerNorm(self.c_in)

@@ -75,7 +73,7 @@ class MSAAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, m, z=None, mask=None):
+    def forward(self, m, chunk_size, z=None, mask=None):
         """
         Args:
             m:

@@ -117,11 +115,11 @@ class MSAAttention(nn.Module):
             biases.append(z)

         mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:

@@ -135,7 +133,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
     Implements Algorithm 7.
     """

-    def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
+    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
         """
         Args:
             c_m:

@@ -155,7 +153,6 @@ class MSARowAttentionWithPairBias(MSAAttention):
             no_heads,
             pair_bias=True,
             c_z=c_z,
-            chunk_size=chunk_size,
             inf=inf,
         )

@@ -165,7 +162,7 @@ class MSAColumnAttention(MSAAttention):
     Implements Algorithm 8.
     """

-    def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
+    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
         """
         Args:
             c_m:

@@ -183,11 +180,10 @@ class MSAColumnAttention(MSAAttention):
             no_heads=no_heads,
             pair_bias=False,
             c_z=None,
-            chunk_size=chunk_size,
             inf=inf,
         )

-    def forward(self, m, mask=None):
+    def forward(self, m, chunk_size, mask=None):
         """
         Args:
             m:

@@ -200,7 +196,7 @@ class MSAColumnAttention(MSAAttention):
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = super().forward(m, mask=mask)
+        m = super().forward(m, chunk_size=chunk_size, mask=mask)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)

@@ -211,14 +207,13 @@ class MSAColumnAttention(MSAAttention):
 class MSAColumnGlobalAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, chunk_size=4, inf=1e9, eps=1e-10
+        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10
     ):
         super(MSAColumnGlobalAttention, self).__init__()

         self.c_in = c_in
         self.c_hidden = c_hidden
         self.no_heads = no_heads
-        self.chunk_size = chunk_size
         self.inf = inf
         self.eps = eps

@@ -233,7 +228,7 @@ class MSAColumnGlobalAttention(nn.Module):
         )

     def forward(
-        self, m: torch.Tensor, mask: Optional[torch.Tensor] = None
+        self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

@@ -256,11 +251,11 @@ class MSAColumnGlobalAttention(nn.Module):
             "m": m,
             "mask": mask,
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             m = chunk_layer(
                 self.global_attention,
                 mha_input,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(m.shape[:-2]),
             )
         else:
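MSAColumnAttention now forwards the per-call chunk_size to its parent class rather than relying on a stored value, so any caller (or subclass) must supply it on every forward pass. A hedged usage sketch of the new call shape (tensor dimensions follow the [*, N_seq, N_res, C] convention used in this file; the dimensions chosen here are arbitrary):

    import torch
    from openfold.model.msa import MSAColumnAttention

    attn = MSAColumnAttention(c_m=256, c_hidden=32, no_heads=8)
    m = torch.randn(1, 128, 64, 256)       # [*, N_seq, N_res, C_m]
    mask = torch.ones(1, 128, 64)           # [*, N_seq, N_res]

    # chunk_size is now a forward argument; passing None disables sub-batching.
    m = attn(m, chunk_size=4, mask=mask)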
openfold/model/outer_product_mean.py

@@ -26,7 +26,7 @@ class OuterProductMean(nn.Module):
     Implements Algorithm 10.
     """

-    def __init__(self, c_m, c_z, c_hidden, chunk_size=4, eps=1e-3):
+    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
         """
         Args:
             c_m:

@@ -40,7 +40,6 @@ class OuterProductMean(nn.Module):
         self.c_z = c_z
         self.c_hidden = c_hidden
-        self.chunk_size = chunk_size
         self.eps = eps

         self.layer_norm = nn.LayerNorm(c_m)

@@ -60,7 +59,7 @@ class OuterProductMean(nn.Module):
         return outer

-    def forward(self, m, mask=None):
+    def forward(self, m, chunk_size, mask=None):
         """
         Args:
             m:

@@ -84,7 +83,7 @@ class OuterProductMean(nn.Module):
         a = a.transpose(-2, -3)
         b = b.transpose(-2, -3)

-        if self.chunk_size is not None:
+        if chunk_size is not None:
             # Since the "batch dim" in this case is not a true batch dimension
             # (in that the shape of the output depends on it), we need to
             # iterate over it ourselves

@@ -95,7 +94,7 @@ class OuterProductMean(nn.Module):
                 outer = chunk_layer(
                     partial(self._opm, b=b_prime),
                     {"a": a_prime},
-                    chunk_size=self.chunk_size,
+                    chunk_size=chunk_size,
                     no_batch_dims=1,
                 )
                 out.append(outer)
openfold/model/pair_transition.py

@@ -25,7 +25,7 @@ class PairTransition(nn.Module):
     Implements Algorithm 15.
     """

-    def __init__(self, c_z, n, chunk_size=4):
+    def __init__(self, c_z, n):
         """
         Args:
             c_z:

@@ -38,7 +38,6 @@ class PairTransition(nn.Module):
         self.c_z = c_z
         self.n = n
-        self.chunk_size = chunk_size

         self.layer_norm = nn.LayerNorm(self.c_z)
         self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")

@@ -55,7 +54,7 @@ class PairTransition(nn.Module):
         return z

-    def forward(self, z, mask=None):
+    def forward(self, z, chunk_size, mask=None):
         """
         Args:
             z:

@@ -74,11 +73,11 @@ class PairTransition(nn.Module):
         z = self.layer_norm(z)

         inp = {"z": z, "mask": mask}
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             z = chunk_layer(
                 self._transition,
                 inp,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(z.shape[:-2]),
             )
         else:
openfold/model/primitives.py

@@ -260,11 +260,14 @@ class Attention(nn.Module):
         # [*, H, Q, K]
         a = torch.matmul(
-            permute_final_dims(q, (0, 2, 1, 3)),  # [*, H, Q, C_hidden]
-            permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
+            permute_final_dims(q, (1, 0, 2)),  # [*, H, Q, C_hidden]
+            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, K]
         )

+        del q, k
+
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a = a * norm
+        a *= norm
+
         if biases is not None:
             for b in biases:
                 a = a + b

@@ -273,7 +276,7 @@ class Attention(nn.Module):
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            permute_final_dims(v, (0, 2, 1, 3)),  # [*, H, V, C_hidden]
+            permute_final_dims(v, (1, 0, 2)),  # [*, H, V, C_hidden]
         )

         # [*, Q, H, C_hidden]
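The permutation tuples in Attention shrink from four indices to three: instead of assuming a fixed four-dimensional layout, the module now permutes only the last three dims (queries/keys, heads, channels) and leaves any number of leading batch dims untouched, which is what lets chunk_layer slice those leading dims freely. A small sketch under the assumption that permute_final_dims(t, inds) reorders only the trailing len(inds) dimensions:

    import torch

    def permute_final_dims(t: torch.Tensor, inds):
        # Assumed semantics: reorder the last len(inds) dims, keep all leading dims.
        zero = -1 * len(inds)
        batch = list(range(t.dim() + zero))
        return t.permute(batch + [zero + i for i in inds])

    q = torch.randn(2, 7, 64, 8, 16)          # [..., Q, H, C_hidden] with two extra batch dims
    h_first = permute_final_dims(q, (1, 0, 2))
    print(h_first.shape)                      # torch.Size([2, 7, 8, 64, 16]) -> [..., H, Q, C_hidden]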
openfold/model/template.py

@@ -45,7 +45,7 @@ class TemplatePointwiseAttention(nn.Module):
     Implements Algorithm 17.
     """

-    def __init__(self, c_t, c_z, c_hidden, no_heads, chunk_size, inf, **kwargs):
+    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
         """
         Args:
             c_t:

@@ -61,7 +61,6 @@ class TemplatePointwiseAttention(nn.Module):
         self.c_z = c_z
         self.c_hidden = c_hidden
         self.no_heads = no_heads
-        self.chunk_size = chunk_size
         self.inf = inf

         self.mha = Attention(

@@ -73,7 +72,7 @@ class TemplatePointwiseAttention(nn.Module):
             gating=False,
         )

-    def forward(self, t, z, template_mask=None):
+    def forward(self, t, z, chunk_size, template_mask=None):
         """
         Args:
             t:

@@ -106,11 +105,11 @@ class TemplatePointwiseAttention(nn.Module):
             "v_x": t,
             "biases": [bias],
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             z = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(z.shape[:-2]),
             )
         else:

@@ -131,7 +130,6 @@ class TemplatePairStackBlock(nn.Module):
         no_heads,
         pair_transition_n,
         dropout_rate,
-        chunk_size,
         inf,
         **kwargs,
     ):

@@ -143,7 +141,6 @@ class TemplatePairStackBlock(nn.Module):
         self.no_heads = no_heads
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
-        self.chunk_size = chunk_size
         self.inf = inf

         self.dropout_row = DropoutRowwise(self.dropout_rate)

@@ -153,14 +150,12 @@ class TemplatePairStackBlock(nn.Module):
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
-            chunk_size=chunk_size,
             inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
-            chunk_size=chunk_size,
             inf=inf,
         )

@@ -176,15 +171,20 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition = PairTransition(
             self.c_t,
             self.pair_transition_n,
-            chunk_size=chunk_size,
         )

-    def forward(self, z, mask, _mask_trans=True):
-        z = z + self.dropout_row(self.tri_att_start(z, mask=mask))
-        z = z + self.dropout_col(self.tri_att_end(z, mask=mask))
+    def forward(self, z, mask, chunk_size, _mask_trans=True):
+        z = z + self.dropout_row(
+            self.tri_att_start(z, chunk_size=chunk_size, mask=mask)
+        )
+        z = z + self.dropout_col(
+            self.tri_att_end(z, chunk_size=chunk_size, mask=mask)
+        )
         z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
         z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
-        z = z + self.pair_transition(z, mask=mask if _mask_trans else None)
+        z = z + self.pair_transition(
+            z, chunk_size=chunk_size, mask=mask if _mask_trans else None
+        )

         return z

@@ -204,7 +204,6 @@ class TemplatePairStack(nn.Module):
         pair_transition_n,
         dropout_rate,
         blocks_per_ckpt,
-        chunk_size,
         inf=1e9,
         **kwargs,
     ):

@@ -225,9 +224,6 @@ class TemplatePairStack(nn.Module):
             blocks_per_ckpt:
                 Number of blocks per activation checkpoint. None disables
                 activation checkpointing
-            chunk_size:
-                Size of subbatches. A higher value increases throughput at
-                the cost of memory
         """
         super(TemplatePairStack, self).__init__()

@@ -242,7 +238,6 @@ class TemplatePairStack(nn.Module):
                 no_heads=no_heads,
                 pair_transition_n=pair_transition_n,
                 dropout_rate=dropout_rate,
-                chunk_size=chunk_size,
                 inf=inf,
             )
             self.blocks.append(block)

@@ -253,6 +248,7 @@ class TemplatePairStack(nn.Module):
         self,
         t: torch.tensor,
         mask: torch.tensor,
+        chunk_size: int,
         _mask_trans: bool = True,
     ):
         """

@@ -269,6 +265,7 @@ class TemplatePairStack(nn.Module):
             partial(
                 b,
                 mask=mask,
+                chunk_size=chunk_size,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
openfold/model/triangular_attention.py

@@ -28,7 +28,7 @@ from openfold.utils.tensor_utils import (
 class TriangleAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, starting, chunk_size=4, inf=1e9
+        self, c_in, c_hidden, no_heads, starting, inf=1e9
     ):
         """
         Args:

@@ -45,7 +45,6 @@ class TriangleAttention(nn.Module):
         self.c_hidden = c_hidden
         self.no_heads = no_heads
         self.starting = starting
-        self.chunk_size = chunk_size
         self.inf = inf

         self.layer_norm = nn.LayerNorm(self.c_in)

@@ -56,7 +55,7 @@ class TriangleAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, x, mask=None):
+    def forward(self, x, chunk_size, mask=None):
         """
         Args:
             x:

@@ -93,11 +92,11 @@ class TriangleAttention(nn.Module):
             "v_x": x,
             "biases": [mask_bias, triangle_bias],
         }
-        if self.chunk_size is not None:
+        if chunk_size is not None:
             x = chunk_layer(
                 self.mha,
                 mha_inputs,
-                chunk_size=self.chunk_size,
+                chunk_size=chunk_size,
                 no_batch_dims=len(x.shape[:-2]),
             )
         else:
train_openfold.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+os.environ["CUDA_VISIBLE_DEVICES"] = "5"
 import random
 import time

@@ -26,7 +26,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def __init__(self, config):
         super(OpenFoldWrapper, self).__init__()
         self.config = config
-        self.model = AlphaFold(config.model)
+        self.model = AlphaFold(config)
         self.loss = AlphaFoldLoss(config.loss)
         self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay)

@@ -50,6 +50,9 @@ class OpenFoldWrapper(pl.LightningModule):
         with open("prediction/preds_" + str(time.strftime("%H:%M:%S")) + ".pickle", "wb") as f:
             pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL)

+    #def validation_step(self, batch, batch_idx):
+    #    outputs = self(batch)
+
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-8

@@ -64,6 +67,7 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

+
 def main(args):
     config = model_config(
         "model_1",
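Because AlphaFold now splits the config itself (keeping config.globals and narrowing to config.model internally, see openfold/model/model.py above), the Lightning wrapper passes the whole config object instead of config.model. A hedged sketch of constructing the wrapper under the new layout (trainer setup and argument parsing omitted):

    from openfold.config import model_config
    from train_openfold import OpenFoldWrapper  # the LightningModule shown above

    config = model_config("model_1", train=True, low_prec=True)
    config.globals.eval_chunk_size = 4   # chunk size used once the module is in eval mode

    # After this commit the wrapper hands the full config to AlphaFold rather than config.model.
    module = OpenFoldWrapper(config)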