Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
68828c49
Commit
68828c49
authored
Apr 17, 2023
by
Christina Floristean
Browse files
Multimer v3 updates
parent
736f27fd
Changes
17
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
568 additions
and
270 deletions
+568
-270
openfold/config.py
openfold/config.py
+19
-0
openfold/model/evoformer.py
openfold/model/evoformer.py
+212
-144
openfold/model/structure_module.py
openfold/model/structure_module.py
+32
-53
openfold/model/template.py
openfold/model/template.py
+23
-8
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+191
-7
openfold/utils/feats.py
openfold/utils/feats.py
+6
-15
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+5
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+41
-24
scripts/download_alphafold_dbs.sh
scripts/download_alphafold_dbs.sh
+6
-6
scripts/download_alphafold_params.sh
scripts/download_alphafold_params.sh
+1
-1
scripts/download_mgnify.sh
scripts/download_mgnify.sh
+2
-2
scripts/download_pdb_seqres.sh
scripts/download_pdb_seqres.sh
+4
-0
scripts/download_uniref30.sh
scripts/download_uniref30.sh
+4
-2
tests/config.py
tests/config.py
+1
-1
tests/test_evoformer.py
tests/test_evoformer.py
+5
-0
tests/test_template.py
tests/test_template.py
+3
-0
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+13
-5
No files found.
openfold/config.py
View file @
68828c49
import
re
import
copy
import
copy
import
importlib
import
importlib
import
ml_collections
as
mlc
import
ml_collections
as
mlc
...
@@ -155,6 +156,18 @@ def model_config(
...
@@ -155,6 +156,18 @@ def model_config(
elif
"multimer"
in
name
:
elif
"multimer"
in
name
:
c
.
globals
.
is_multimer
=
True
c
.
globals
.
is_multimer
=
True
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
loss
.
masked_msa
.
num_classes
=
22
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
c
.
model
.
evoformer
.
num_msa
=
252
c
.
model
.
evoformer
.
num_extra_msa
=
1152
c
.
model
.
evoformer
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
elif
name
==
'model_4_multimer_v3'
:
c
.
model
.
evoformer
.
num_extra_msa
=
1152
elif
name
==
'model_5_multimer_v3'
:
c
.
model
.
evoformer
.
num_extra_msa
=
1152
for
k
,
v
in
multimer_model_config_update
.
items
():
for
k
,
v
in
multimer_model_config_update
.
items
():
c
.
model
[
k
]
=
v
c
.
model
[
k
]
=
v
...
@@ -438,6 +451,7 @@ config = mlc.ConfigDict(
...
@@ -438,6 +451,7 @@ config = mlc.ConfigDict(
"pair_transition_n"
:
2
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
False
,
"tri_mul_first"
:
False
,
"fuse_projection_weights"
:
False
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"tune_chunk_size"
:
tune_chunk_size
,
"tune_chunk_size"
:
tune_chunk_size
,
"inf"
:
1e9
,
"inf"
:
1e9
,
...
@@ -487,6 +501,7 @@ config = mlc.ConfigDict(
...
@@ -487,6 +501,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"opm_first"
:
False
,
"opm_first"
:
False
,
"fuse_projection_weights"
:
False
,
"clear_cache_between_blocks"
:
False
,
"clear_cache_between_blocks"
:
False
,
"tune_chunk_size"
:
tune_chunk_size
,
"tune_chunk_size"
:
tune_chunk_size
,
"inf"
:
1e9
,
"inf"
:
1e9
,
...
@@ -510,6 +525,7 @@ config = mlc.ConfigDict(
...
@@ -510,6 +525,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"opm_first"
:
False
,
"opm_first"
:
False
,
"fuse_projection_weights"
:
False
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"clear_cache_between_blocks"
:
False
,
"tune_chunk_size"
:
tune_chunk_size
,
"tune_chunk_size"
:
tune_chunk_size
,
...
@@ -671,6 +687,7 @@ multimer_model_config_update = {
...
@@ -671,6 +687,7 @@ multimer_model_config_update = {
"pair_transition_n"
:
2
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
True
,
"tri_mul_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"inf"
:
1e9
,
"inf"
:
1e9
,
},
},
...
@@ -701,6 +718,7 @@ multimer_model_config_update = {
...
@@ -701,6 +718,7 @@ multimer_model_config_update = {
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"clear_cache_between_blocks"
:
True
,
"clear_cache_between_blocks"
:
True
,
"inf"
:
1e9
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"eps"
:
eps
,
# 1e-10,
...
@@ -723,6 +741,7 @@ multimer_model_config_update = {
...
@@ -723,6 +741,7 @@ multimer_model_config_update = {
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"clear_cache_between_blocks"
:
False
,
"inf"
:
1e9
,
"inf"
:
1e9
,
...
...
openfold/model/evoformer.py
View file @
68828c49
This diff is collapsed.
Click to expand it.
openfold/model/structure_module.py
View file @
68828c49
...
@@ -178,7 +178,7 @@ class PointProjection(nn.Module):
...
@@ -178,7 +178,7 @@ class PointProjection(nn.Module):
def
forward
(
self
,
def
forward
(
self
,
activations
:
torch
.
Tensor
,
activations
:
torch
.
Tensor
,
rigids
:
Union
[
Rigid
,
Rigid3Array
],
rigids
:
Union
[
Rigid
,
Rigid3Array
],
)
->
Union
[
Vec3Array
,
Tuple
[
Vec3Array
,
Vec3Array
],
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
# TODO: Needs to run in high precision during training
# TODO: Needs to run in high precision during training
points_local
=
self
.
linear
(
activations
)
points_local
=
self
.
linear
(
activations
)
...
@@ -398,20 +398,14 @@ class InvariantPointAttention(nn.Module):
...
@@ -398,20 +398,14 @@ class InvariantPointAttention(nn.Module):
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
if
self
.
is_multimer
:
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
.
unsqueeze
(
-
3
)
-
k_pts
.
unsqueeze
(
-
4
)
# [*, N_res, N_res, H, P_q]
if
(
inplace_safe
):
pt_att
=
sum
([
c
**
2
for
c
in
pt_att
])
pt_att
*
=
pt_att
else
:
else
:
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
pt_att
**
2
if
(
inplace_safe
):
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
pt_att
*=
pt_att
else
:
pt_att
=
pt_att
**
2
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
...
@@ -427,6 +421,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -427,6 +421,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
# [*, N_res, N_res, H]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
# [*, N_res, N_res]
# [*, N_res, N_res]
square_mask
=
mask
.
unsqueeze
(
-
1
)
*
mask
.
unsqueeze
(
-
2
)
square_mask
=
mask
.
unsqueeze
(
-
1
)
*
mask
.
unsqueeze
(
-
2
)
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
...
@@ -460,51 +455,35 @@ class InvariantPointAttention(nn.Module):
...
@@ -460,51 +455,35 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
# [*, N_res, H * C_hidden]
o
=
flatten_final_dims
(
o
,
2
)
o
=
flatten_final_dims
(
o
,
2
)
if
self
.
is_multimer
:
# [*, H, 3, N_res, P_v]
# As DeepMind explains, this manual matmul ensures that the operation
if
(
inplace_safe
):
# happens in float32.
v_pts
=
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))
# [*, N_res, H, P_v]
o_pt
=
[
o_pt
=
v_pts
[...,
None
,
:,
:,
:]
*
permute_final_dims
(
a
,
(
1
,
2
,
0
)).
unsqueeze
(
-
1
)
torch
.
matmul
(
a
,
v
.
to
(
a
.
dtype
))
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
for
v
in
torch
.
unbind
(
v_pts
,
dim
=-
3
)
]
# [*, N_res, H, P_v]
o_pt
=
torch
.
stack
(
o_pt
,
dim
=-
3
)
o_pt
=
r
[...,
None
,
None
].
apply_inverse_to_point
(
o_pt
)
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
o_pt
.
shape
[:
-
2
]
+
(
-
1
,))
# [*, N_res, H * P_v]
o_pt_norm
=
o_pt
.
norm
(
self
.
eps
)
else
:
else
:
# [*, H, 3, N_res, P_v]
o_pt
=
torch
.
sum
(
if
(
inplace_safe
):
(
v_pts
=
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))
a
[...,
None
,
:,
:,
None
]
o_pt
=
[
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
torch
.
matmul
(
a
,
v
.
to
(
a
.
dtype
))
),
for
v
in
torch
.
unbind
(
v_pts
,
dim
=-
3
)
dim
=-
2
,
]
)
o_pt
=
torch
.
stack
(
o_pt
,
dim
=-
3
)
else
:
o_pt
=
torch
.
sum
(
(
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
o_pt_norm
=
flatten_final_dims
(
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
2
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
2
)
)
# [*, N_res, H * P_v, 3]
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
torch
.
unbind
(
o_pt
,
dim
=-
1
)
o_pt
=
torch
.
unbind
(
o_pt
,
dim
=-
1
)
if
(
_offload_inference
):
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
device
)
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
device
)
...
...
openfold/model/template.py
View file @
68828c49
...
@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
...
@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from
openfold.model.triangular_multiplicative_update
import
(
from
openfold.model.triangular_multiplicative_update
import
(
TriangleMultiplicationOutgoing
,
TriangleMultiplicationOutgoing
,
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
FusedTriangleMultiplicationOutgoing
,
FusedTriangleMultiplicationIncoming
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.chunk_utils
import
(
from
openfold.utils.chunk_utils
import
(
...
@@ -155,6 +157,7 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -155,6 +157,7 @@ class TemplatePairStackBlock(nn.Module):
pair_transition_n
:
int
,
pair_transition_n
:
int
,
dropout_rate
:
float
,
dropout_rate
:
float
,
tri_mul_first
:
bool
,
tri_mul_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
inf
:
float
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -185,14 +188,24 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -185,14 +188,24 @@ class TemplatePairStackBlock(nn.Module):
inf
=
inf
,
inf
=
inf
,
)
)
self
.
tri_mul_out
=
TriangleMultiplicationOutgoing
(
if
fuse_projection_weights
:
self
.
c_t
,
self
.
tri_mul_out
=
FusedTriangleMultiplicationOutgoing
(
self
.
c_hidden_tri_mul
,
self
.
c_t
,
)
self
.
c_hidden_tri_mul
,
self
.
tri_mul_in
=
TriangleMultiplicationIncoming
(
)
self
.
c_t
,
self
.
tri_mul_in
=
FusedTriangleMultiplicationIncoming
(
self
.
c_hidden_tri_mul
,
self
.
c_t
,
)
self
.
c_hidden_tri_mul
,
)
else
:
self
.
tri_mul_out
=
TriangleMultiplicationOutgoing
(
self
.
c_t
,
self
.
c_hidden_tri_mul
,
)
self
.
tri_mul_in
=
TriangleMultiplicationIncoming
(
self
.
c_t
,
self
.
c_hidden_tri_mul
,
)
self
.
pair_transition
=
PairTransition
(
self
.
pair_transition
=
PairTransition
(
self
.
c_t
,
self
.
c_t
,
...
@@ -329,6 +342,7 @@ class TemplatePairStack(nn.Module):
...
@@ -329,6 +342,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n
,
pair_transition_n
,
dropout_rate
,
dropout_rate
,
tri_mul_first
,
tri_mul_first
,
fuse_projection_weights
,
blocks_per_ckpt
,
blocks_per_ckpt
,
tune_chunk_size
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
inf
=
1e9
,
inf
=
1e9
,
...
@@ -366,6 +380,7 @@ class TemplatePairStack(nn.Module):
...
@@ -366,6 +380,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n
=
pair_transition_n
,
pair_transition_n
=
pair_transition_n
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
tri_mul_first
=
tri_mul_first
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
inf
=
inf
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
...
...
openfold/model/triangular_multiplicative_update.py
View file @
68828c49
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
from
functools
import
partialmethod
from
functools
import
partialmethod
from
typing
import
Optional
from
typing
import
Optional
from
abc
import
ABC
,
abstractmethod
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
...
@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
from
openfold.utils.tensor_utils
import
add
,
permute_final_dims
from
openfold.utils.tensor_utils
import
add
,
permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
Base
TriangleMultiplicativeUpdate
(
nn
.
Module
,
ABC
):
"""
"""
Implements Algorithms 11 and 12.
Implements Algorithms 11 and 12.
"""
"""
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
@
abstractmethod
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
):
"""
"""
Args:
Args:
c_z:
c_z:
...
@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
c:
c:
Hidden channel dimension
Hidden channel dimension
"""
"""
super
(
TriangleMultiplicativeUpdate
,
self
).
__init__
()
super
(
Base
TriangleMultiplicativeUpdate
,
self
).
__init__
()
self
.
c_z
=
c_z
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
self
.
_outgoing
=
_outgoing
self
.
_outgoing
=
_outgoing
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_b_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
self
.
linear_z
=
Linear
(
self
.
c_hidden
,
self
.
c_z
,
init
=
"final"
)
self
.
linear_z
=
Linear
(
self
.
c_hidden
,
self
.
c_z
,
init
=
"final"
)
...
@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
@
abstractmethod
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
pass
class
TriangleMultiplicativeUpdate
(
BaseTriangleMultiplicativeUpdate
):
"""
Implements Algorithms 11 and 12.
"""
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super
(
TriangleMultiplicativeUpdate
,
self
).
__init__
(
c_z
=
c_z
,
c_hidden
=
c_hidden
,
_outgoing
=
_outgoing
)
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_b_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
def
_inference_forward
(
self
,
def
_inference_forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -425,3 +463,149 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
...
@@ -425,3 +463,149 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
Implements Algorithm 12.
Implements Algorithm 12.
"""
"""
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
class
FusedTriangleMultiplicativeUpdate
(
BaseTriangleMultiplicativeUpdate
):
"""
Implements Algorithms 11 and 12.
"""
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super
(
FusedTriangleMultiplicativeUpdate
,
self
).
__init__
(
c_z
=
c_z
,
c_hidden
=
c_hidden
,
_outgoing
=
_outgoing
)
self
.
linear_ab_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
*
2
)
self
.
linear_ab_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
*
2
,
init
=
"gating"
)
def
_inference_forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_inplace_chunk_size
:
Optional
[
int
]
=
None
,
with_add
:
bool
=
True
,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
"""
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
def
compute_projection_helper
(
pair
,
mask
):
pair
=
self
.
layer_norm_in
(
pair
)
p
=
self
.
linear_ab_g
(
pair
)
p
.
sigmoid_
()
p
*=
self
.
linear_ab_p
(
pair
)
p
*=
mask
return
p
def
compute_projection
(
pair
,
mask
):
p
=
compute_projection_helper
(
pair
,
mask
)
a
=
p
[...,
:
self
.
c_hidden
]
b
=
p
[...,
self
.
c_hidden
:]
return
a
,
b
a
,
b
=
compute_projection
(
z
,
mask
)
x
=
self
.
_combine_projections
(
a
,
b
,
_inplace_chunk_size
=
_inplace_chunk_size
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
linear_g
(
z
)
g
.
sigmoid_
()
x
*=
g
if
(
with_add
):
z
+=
x
else
:
z
=
x
return
z
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
)
->
torch
.
Tensor
:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if
(
inplace_safe
):
x
=
self
.
_inference_forward
(
z
,
mask
,
_inplace_chunk_size
=
_inplace_chunk_size
,
with_add
=
_add_with_inplace
,
)
return
x
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
z
=
self
.
layer_norm_in
(
z
)
ab
=
mask
ab
=
ab
*
self
.
sigmoid
(
self
.
linear_ab_g
(
z
))
ab
=
ab
*
self
.
linear_ab_p
(
z
)
a
=
ab
[...,
:
self
.
c_hidden
]
b
=
ab
[...,
self
.
c_hidden
:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
x
=
self
.
_combine_projections
(
a
.
float
(),
b
.
float
())
else
:
x
=
self
.
_combine_projections
(
a
,
b
)
del
a
,
b
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
x
=
x
*
g
return
x
class
FusedTriangleMultiplicationOutgoing
(
FusedTriangleMultiplicativeUpdate
):
"""
Implements Algorithm 11.
"""
__init__
=
partialmethod
(
FusedTriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
True
)
class
FusedTriangleMultiplicationIncoming
(
FusedTriangleMultiplicativeUpdate
):
"""
Implements Algorithm 12.
"""
__init__
=
partialmethod
(
FusedTriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
openfold/utils/feats.py
View file @
68828c49
...
@@ -189,7 +189,7 @@ def torsion_angles_to_frames(
...
@@ -189,7 +189,7 @@ def torsion_angles_to_frames(
rrgdf
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
):
):
rigid_type
=
Rigid
if
isinstance
(
r
,
Rigid
)
else
rigid_matrix_vector
.
Rigid3Array
rigid_type
=
type
(
r
)
# [*, N, 8, 4, 4]
# [*, N, 8, 4, 4]
default_4x4
=
rrgdf
[
aatype
,
...]
default_4x4
=
rrgdf
[
aatype
,
...]
...
@@ -217,18 +217,14 @@ def torsion_angles_to_frames(
...
@@ -217,18 +217,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# This follows the original code rather than the supplement, which uses
# different indices.
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
3
,
3
))
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
4
,
4
))
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
[...,
2
,
1
:
3
]
=
alpha
if
isinstance
(
r
,
Rigid
):
all_rots
=
rigid_type
.
from_tensor_4x4
(
all_rots
)
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_frames
=
default_r
.
compose
(
all_rots
)
all_frames
=
default_r
.
compose
(
all_rots
)
else
:
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
@@ -283,16 +279,11 @@ def frames_and_literature_positions_to_atom14_pos(
...
@@ -283,16 +279,11 @@ def frames_and_literature_positions_to_atom14_pos(
)
)
# [*, N, 14]
# [*, N, 14]
atom_mask
=
atom_mask
[
aatype
,
...]
atom_mask
=
atom_mask
[
aatype
,
...].
unsqueeze
(
-
1
)
if
isinstance
(
r
,
Rigid
):
atom_mask
=
atom_mask
.
unsqueeze
(
-
1
)
# [*, N, 14, 3]
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
lit_positions
=
lit_positions
[
aatype
,
...]
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
=
pred_positions
*
atom_mask
pred_positions
=
pred_positions
*
atom_mask
if
isinstance
(
pred_positions
,
vector
.
Vec3Array
):
return
pred_positions
.
to_tensor
()
return
pred_positions
return
pred_positions
openfold/utils/geometry/rigid_matrix_vector.py
View file @
68828c49
...
@@ -67,14 +67,17 @@ class Rigid3Array:
...
@@ -67,14 +67,17 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
vector
.
Vec3Array
:
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
.
to_tensor
()
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
new_point
=
point
-
self
.
translation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
def
invert_apply
(
self
,
point
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
apply_inverse_to_point
(
vector
.
Vec3Array
.
from_array
(
point
)).
to_tensor
()
def
compose_rotation
(
self
,
other_rotation
):
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
...
...
openfold/utils/import_weights.py
View file @
68828c49
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
from
enum
import
Enum
from
enum
import
Enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
...
@@ -191,31 +192,47 @@ def generate_translation_dict(model, version, is_multimer=False):
...
@@ -191,31 +192,47 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
}
}
TriMulOutParams
=
lambda
tri_mul
:
{
def
TriMulOutParams
(
tri_mul
,
outgoing
=
True
):
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
version
):
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
d
=
{
"right_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"projection"
:
LinearParams
(
tri_mul
.
linear_ab_p
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"gate"
:
LinearParams
(
tri_mul
.
linear_ab_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
}
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
else
:
}
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
if
outgoing
:
left_projection
=
LinearParams
(
tri_mul
.
linear_a_p
)
right_projection
=
LinearParams
(
tri_mul
.
linear_b_p
)
left_gate
=
LinearParams
(
tri_mul
.
linear_a_g
)
right_gate
=
LinearParams
(
tri_mul
.
linear_b_g
)
else
:
left_projection
=
LinearParams
(
tri_mul
.
linear_b_p
)
right_projection
=
LinearParams
(
tri_mul
.
linear_a_p
)
left_gate
=
LinearParams
(
tri_mul
.
linear_b_g
)
right_gate
=
LinearParams
(
tri_mul
.
linear_a_g
)
d
=
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_projection"
:
left_projection
,
"right_projection"
:
right_projection
,
"left_gate"
:
left_gate
,
"right_gate"
:
right_gate
,
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
}
# see commit b88f8da on the Alphafold repo
d
.
update
({
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
# iterations of triangle multiplication, which is confusing and not
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
# reproduced in our implementation.
})
TriMulInParams
=
lambda
tri_mul
:
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
return
d
"left_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"right_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
TriMulInParams
=
partial
(
TriMulOutParams
,
outgoing
=
False
)
"left_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
}
PairTransitionParams
=
lambda
pt
:
{
PairTransitionParams
=
lambda
pt
:
{
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
...
...
scripts/download_alphafold_dbs.sh
View file @
68828c49
...
@@ -56,16 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
...
@@ -56,16 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo
"Downloading PDB mmCIF files..."
echo
"Downloading PDB mmCIF files..."
bash
"
${
SCRIPT_DIR
}
/download_pdb_mmcif.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_pdb_mmcif.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading Uni
clust
30..."
echo
"Downloading Uni
ref
30..."
bash
"
${
SCRIPT_DIR
}
/download_uni
clust
30.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_uni
ref
30.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading Uniref90..."
echo
"Downloading Uniref90..."
bash
"
${
SCRIPT_DIR
}
/download_uniref90.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_uniref90.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading PDB SeqRes..."
bash
"
${
SCRIPT_DIR
}
/download_uniclust30.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading UniProt..."
echo
"Downloading UniProt..."
bash
"
${
SCRIPT_DIR
}
/download_uniref90.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_uniprot.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading PDB SeqRes..."
bash
"
${
SCRIPT_DIR
}
/download_pdb_seqres.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"All data downloaded."
echo
"All data downloaded."
scripts/download_alphafold_params.sh
View file @
68828c49
...
@@ -31,7 +31,7 @@ fi
...
@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/params"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/params"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold/alphafold_params_2022-
03
-0
2
.tar"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold/alphafold_params_2022-
12
-0
6
.tar"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
...
...
scripts/download_mgnify.sh
View file @
68828c49
...
@@ -32,8 +32,8 @@ fi
...
@@ -32,8 +32,8 @@ fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/mgnify"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/mgnify"
# Mirror of:
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/20
18_12
/mgy_clusters.fa.gz
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/20
22_05
/mgy_clusters.fa.gz
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/
casp14_versions
/mgy_clusters_20
18_12
.fa.gz"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/
v2.3
/mgy_clusters_20
22_05
.fa.gz"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
...
...
scripts/download_pdb_seqres.sh
View file @
68828c49
...
@@ -36,3 +36,7 @@ BASENAME=$(basename "${SOURCE_URL}")
...
@@ -36,3 +36,7 @@ BASENAME=$(basename "${SOURCE_URL}")
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
aria2c
"
${
SOURCE_URL
}
"
--dir
=
"
${
ROOT_DIR
}
"
aria2c
"
${
SOURCE_URL
}
"
--dir
=
"
${
ROOT_DIR
}
"
# Keep only protein sequences.
grep
--after-context
=
1
--no-group-separator
'>.* mol:protein'
"
${
ROOT_DIR
}
/pdb_seqres.txt"
>
"
${
ROOT_DIR
}
/pdb_seqres_filtered.txt"
mv
"
${
ROOT_DIR
}
/pdb_seqres_filtered.txt"
"
${
ROOT_DIR
}
/pdb_seqres.txt"
scripts/download_uniref30.sh
View file @
68828c49
...
@@ -30,8 +30,10 @@ if ! command -v aria2c &> /dev/null ; then
...
@@ -30,8 +30,10 @@ if ! command -v aria2c &> /dev/null ; then
fi
fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/uniref30"
SOURCE_URL
=
"http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
# Mirror of:
# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
...
...
tests/config.py
View file @
68828c49
...
@@ -2,7 +2,7 @@ import ml_collections as mlc
...
@@ -2,7 +2,7 @@ import ml_collections as mlc
consts
=
mlc
.
ConfigDict
(
consts
=
mlc
.
ConfigDict
(
{
{
"model"
:
"model_1_multimer_v
2
"
,
# monomer:model_1_ptm, multimer: model_1_multimer_v
2
"model"
:
"model_1_multimer_v
3
"
,
# monomer:model_1_ptm, multimer: model_1_multimer_v
3
"is_multimer"
:
True
,
# monomer: False, multimer: True
"is_multimer"
:
True
,
# monomer: False, multimer: True
"chunk_size"
:
4
,
"chunk_size"
:
4
,
"batch_size"
:
2
,
"batch_size"
:
2
,
...
...
tests/test_evoformer.py
View file @
68828c49
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -49,6 +50,7 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -49,6 +50,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -67,6 +69,7 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -67,6 +69,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
opm_first
,
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -177,6 +180,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -177,6 +180,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -194,6 +198,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -194,6 +198,7 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
opm_first
,
fuse_projection_weights
,
ckpt
=
False
,
ckpt
=
False
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
...
tests/test_template.py
View file @
68828c49
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -78,6 +79,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -78,6 +79,7 @@ class TestTemplatePairStack(unittest.TestCase):
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
tri_mul_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
blocks_per_ckpt
=
None
blocks_per_ckpt
=
None
chunk_size
=
4
chunk_size
=
4
inf
=
1e7
inf
=
1e7
...
@@ -92,6 +94,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -92,6 +94,7 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n
=
pt_inner_dim
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
...
tests/test_triangular_multiplicative_update.py
View file @
68828c49
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
re
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
openfold.model.triangular_multiplicative_update
import
*
from
openfold.model.triangular_multiplicative_update
import
*
...
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z
=
consts
.
c_z
c_z
=
consts
.
c_z
c
=
11
c
=
11
tm
=
TriangleMultiplicationOutgoing
(
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
):
c_z
,
tm
=
FusedTriangleMultiplicationOutgoing
(
c
,
c_z
,
)
c
,
)
else
:
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
)
n_res
=
consts
.
c_z
n_res
=
consts
.
c_z
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config
.
model
.
global_config
,
config
.
model
.
global_config
,
name
=
name
,
name
=
name
,
)
)
act
=
tri_mul
(
act
=
pair_act
,
mask
=
pair_mask
)
act
=
tri_mul
(
pair_act
,
pair_mask
)
return
act
return
act
f
=
hk
.
transform
(
run_tri_mul
)
f
=
hk
.
transform
(
run_tri_mul
)
...
@@ -89,6 +96,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -89,6 +96,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment