Commit 34e9363c
Authored Nov 19, 2021 by Gustaf Ahdritz
parent 34e4e6ce

Make more components TorchScript-able, add tracing
Showing 15 changed files with 516 additions and 200 deletions.
openfold/model/evoformer.py (+21 −11)
openfold/model/msa.py (+81 −28)
openfold/model/outer_product_mean.py (+35 −18)
openfold/model/pair_transition.py (+22 −9)
openfold/model/structure_module.py (+8 −9)
openfold/model/template.py (+45 −24)
openfold/model/torchscript.py (+200 −15)
openfold/model/triangular_attention.py (+30 −14)
openfold/model/triangular_multiplicative_update.py (+37 −42)
openfold/utils/affine_utils.py (+1 −1)
openfold/utils/checkpointing.py (+2 −1)
openfold/utils/import_weights.py (+6 −1)
openfold/utils/tensor_utils.py (+19 −19)
run_pretrained_openfold.py (+3 −3)
train_openfold.py (+6 −5)
openfold/model/evoformer.py

@@ -71,11 +71,25 @@ class MSATransition(nn.Module):
         m = self.linear_2(m) * mask
         return m

+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self._transition,
+            {"m": m, "mask": mask},
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
     def forward(self,
         m: torch.Tensor,
-        mask: torch.Tensor = None,
-        chunk_size: int = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -95,16 +109,10 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)

-        inp = {"m": m, "mask": mask}
         if chunk_size is not None:
-            m = chunk_layer(
-                self._transition,
-                inp,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, mask, chunk_size)
         else:
-            m = self._transition(**inp)
+            m = self._transition(m, mask)

         return m

@@ -201,9 +209,11 @@ class EvoformerBlock(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
-        chunk_size: int,
+        chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        #print(torch.cuda.memory_summary())
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
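Note: the `_chunk` helpers introduced throughout this commit are decorated with `@torch.jit.ignore`, so `torch.jit.script` compiles the forward path but dispatches the chunked path back to the Python interpreter. A minimal sketch of the pattern, using a hypothetical toy module rather than the repository's classes:

    import torch
    import torch.nn as nn
    from typing import Optional

    class ToyTransition(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.linear = nn.Linear(c, c)

        def _transition(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

        @torch.jit.ignore
        def _chunk(self, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
            # Never compiled: the scripted forward calls back into Python here.
            pieces = [self._transition(p) for p in x.split(chunk_size, dim=0)]
            return torch.cat(pieces, dim=0)

        def forward(self, x: torch.Tensor, chunk_size: Optional[int] = None) -> torch.Tensor:
            if chunk_size is not None:
                return self._chunk(x, chunk_size)
            return self._transition(x)

    scripted = torch.jit.script(ToyTransition(8))  # forward is compiled, _chunk is not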
openfold/model/msa.py

@@ -16,7 +16,7 @@
 import math
 import torch
 import torch.nn as nn
-from typing import Optional
+from typing import Optional, List

 from openfold.model.primitives import Linear, Attention, GlobalAttention
 from openfold.utils.tensor_utils import (

@@ -63,6 +63,8 @@ class MSAAttention(nn.Module):
         self.layer_norm_m = nn.LayerNorm(self.c_in)

+        self.layer_norm_z = None
+        self.linear_z = None
         if self.pair_bias:
             self.layer_norm_z = nn.LayerNorm(self.c_z)
             self.linear_z = Linear(

@@ -73,7 +75,25 @@ class MSAAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, m, chunk_size, z=None, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self.mha,
+            {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+    ) -> torch.Tensor:
         """
         Args:
             m:

@@ -83,6 +103,11 @@ class MSAAttention(nn.Module):
                 pair_bias is True
             mask:
                 [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
         """
         # [*, N_seq, N_res, C_m]
         m = self.layer_norm_m(m)

@@ -106,7 +131,11 @@ class MSAAttention(nn.Module):
         biases = [bias]
-        if self.pair_bias:
+        if (self.pair_bias and
+            z is not None and                   # For the
+            self.layer_norm_z is not None and   # benefit of
+            self.linear_z is not None           # TorchScript
+        ):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)

@@ -118,16 +147,10 @@ class MSAAttention(nn.Module):
             biases.append(z)

-        mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
         if chunk_size is not None:
-            m = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, biases, chunk_size)
         else:
-            m = self.mha(**mha_inputs)
+            m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)

         return m

@@ -161,9 +184,12 @@ class MSARowAttentionWithPairBias(MSAAttention):
     )

-class MSAColumnAttention(MSAAttention):
+class MSAColumnAttention(nn.Module):
     """
     Implements Algorithm 8.
+
+    By rights, this should also be a subclass of MSAAttention. Alas,
+    most inheritance isn't supported by TorchScript.
     """
     def __init__(self, c_m, c_hidden, no_heads, inf=1e9):

@@ -178,7 +204,14 @@ class MSAColumnAttention(MSAAttention):
             inf:
                 Large number used to construct attention masks
         """
-        super(MSAColumnAttention, self).__init__(
+        super(MSAColumnAttention, self).__init__()
+
+        self.c_m = c_m
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+
+        self._msa_att = MSAAttention(
             c_in=c_m,
             c_hidden=c_hidden,
             no_heads=no_heads,

@@ -187,31 +220,40 @@ class MSAColumnAttention(MSAAttention):
             inf=inf,
         )

-    def forward(self, m, chunk_size, mask=None):
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding
             mask:
                 [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = super().forward(m, chunk_size=chunk_size, mask=mask)
+        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
         if mask is not None:
             mask = mask.transpose(-1, -2)

         return m


 class MSAColumnGlobalAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10
+        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
     ):
         super(MSAColumnGlobalAttention, self).__init__()

@@ -231,8 +273,28 @@ class MSAColumnGlobalAttention(nn.Module):
             eps=eps,
         )

+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_input = {
+            "m": m,
+            "mask": mask,
+        }
+        return chunk_layer(
+            self.global_attention,
+            mha_input,
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
     def forward(
-        self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None
+        self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

@@ -251,19 +313,10 @@ class MSAColumnGlobalAttention(nn.Module):
         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)

-        mha_input = {
-            "m": m,
-            "mask": mask,
-        }
         if chunk_size is not None:
-            m = chunk_layer(
-                self.global_attention,
-                mha_input,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, mask, chunk_size)
         else:
-            m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
+            m = self.global_attention(m=m, mask=mask)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
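Note: `MSAColumnAttention` switches from inheriting `MSAAttention` to wrapping one in `self._msa_att`, because the TorchScript compiler handles submodule calls but not most Python inheritance. A minimal sketch of the composition pattern with hypothetical toy classes:

    import torch
    import torch.nn as nn

    class Inner(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.linear = nn.Linear(c, c)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    class Outer(nn.Module):
        # Composition instead of "class Outer(Inner)": the compiled forward
        # only sees a submodule call, not an inheritance chain.
        def __init__(self, c: int):
            super().__init__()
            self._inner = Inner(c)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x.transpose(-2, -3)      # column-wise transpose, as in the commit
            x = self._inner(x)
            return x.transpose(-2, -3)

    scripted = torch.jit.script(Outer(8))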
openfold/model/outer_product_mean.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 from functools import partial
+from typing import Optional
+
 import torch
 import torch.nn as nn

@@ -38,6 +40,7 @@ class OuterProductMean(nn.Module):
         """
         super(OuterProductMean, self).__init__()

+        self.c_m = c_m
         self.c_z = c_z
         self.c_hidden = c_hidden
         self.eps = eps

@@ -52,14 +55,43 @@ class OuterProductMean(nn.Module):
         outer = torch.einsum("...bac,...dae->...bdce", a, b)

         # [*, N_res, N_res, C * C]
-        outer = outer.reshape(*outer.shape[:-2], -1)
+        outer = outer.reshape(outer.shape[:-2] + (-1,))

         # [*, N_res, N_res, C_z]
         outer = self.linear_out(outer)

         return outer

-    def forward(self, m, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        chunk_size: int
+    ) -> torch.Tensor:
+        # Since the "batch dim" in this case is not a true batch dimension
+        # (in that the shape of the output depends on it), we need to
+        # iterate over it ourselves
+        a_reshape = a.reshape((-1,) + a.shape[-3:])
+        b_reshape = b.reshape((-1,) + b.shape[-3:])
+        out = []
+        for a_prime, b_prime in zip(a_reshape, b_reshape):
+            outer = chunk_layer(
+                partial(self._opm, b=b_prime),
+                {"a": a_prime},
+                chunk_size=chunk_size,
+                no_batch_dims=1,
+            )
+            out.append(outer)
+        outer = torch.stack(out, dim=0)
+        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
+
+        return outer
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             m:

@@ -84,22 +116,7 @@ class OuterProductMean(nn.Module):
         b = b.transpose(-2, -3)

         if chunk_size is not None:
-            # Since the "batch dim" in this case is not a true batch dimension
-            # (in that the shape of the output depends on it), we need to
-            # iterate over it ourselves
-            a_reshape = a.reshape(-1, *a.shape[-3:])
-            b_reshape = b.reshape(-1, *b.shape[-3:])
-            out = []
-            for a_prime, b_prime in zip(a_reshape, b_reshape):
-                outer = chunk_layer(
-                    partial(self._opm, b=b_prime),
-                    {"a": a_prime},
-                    chunk_size=chunk_size,
-                    no_batch_dims=1,
-                )
-                out.append(outer)
-            outer = torch.stack(out, dim=0)
-            outer = outer.reshape(*a.shape[:-3], *outer.shape[1:])
+            outer = self._chunk(a, b, chunk_size)
         else:
             outer = self._opm(a, b)
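Note: unlike the other `_chunk` helpers, the outer-product "batch" dimension is not a true batch dimension (the output shape depends on it), so the new `_chunk` flattens the real leading dims, loops example by example, and restores the original shape. A sketch of that reshape bookkeeping, with a hypothetical stand-in for the chunked `partial(self._opm, b=b_prime)` call:

    import torch

    def per_example(fn, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Collapse arbitrary leading batch dims, apply fn one example at a
        # time, then restore the original batch shape.
        a_flat = a.reshape((-1,) + a.shape[-3:])
        b_flat = b.reshape((-1,) + b.shape[-3:])
        out = [fn(a_i, b_i) for a_i, b_i in zip(a_flat, b_flat)]
        stacked = torch.stack(out, dim=0)
        return stacked.reshape(a.shape[:-3] + stacked.shape[1:])

    # toy outer-product-style fn: it contracts the N_seq dim, so the output
    # shape depends on the per-example inputs
    fn = lambda a, b: torch.einsum("sic,sjd->ijcd", a, b)
    a = torch.randn(2, 2, 3, 5, 4)   # [batch..., N_seq, N_res_a, C]
    b = torch.randn(2, 2, 3, 6, 4)   # [batch..., N_seq, N_res_b, C]
    print(per_example(fn, a, b).shape)   # torch.Size([2, 2, 5, 6, 4, 4])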
openfold/model/pair_transition.py

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -54,7 +55,25 @@ class PairTransition(nn.Module):
         return z

-    def forward(self, z, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        z: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self._transition,
+            {"z": z, "mask": mask},
+            chunk_size=chunk_size,
+            no_batch_dims=len(z.shape[:-2]),
+        )
+
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+    ) -> torch.Tensor:
         """
         Args:
             z:

@@ -72,15 +91,9 @@ class PairTransition(nn.Module):
         # [*, N_res, N_res, C_z]
         z = self.layer_norm(z)

-        inp = {"z": z, "mask": mask}
         if chunk_size is not None:
-            z = chunk_layer(
-                self._transition,
-                inp,
-                chunk_size=chunk_size,
-                no_batch_dims=len(z.shape[:-2]),
-            )
+            z = self._chunk(z, mask, chunk_size)
         else:
-            z = self._transition(**inp)
+            z = self._transition(z=z, mask=mask)

         return z
openfold/model/structure_module.py

@@ -155,17 +155,16 @@ class InvariantPointAttention(nn.Module):
     """
     Implements Algorithm 22.
     """
     def __init__(
         self,
-        c_s,
-        c_z,
-        c_hidden,
-        no_heads,
-        no_qk_points,
-        no_v_points,
-        inf=1e5,
-        eps=1e-8,
+        c_s: int,
+        c_z: int,
+        c_hidden: int,
+        no_heads: int,
+        no_qk_points: int,
+        no_v_points: int,
+        inf: float = 1e5,
+        eps: float = 1e-8,
     ):
         """
         Args:
openfold/model/template.py

@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import math
+from typing import Optional, List

 import torch
 import torch.nn as nn

@@ -71,7 +72,32 @@ class TemplatePointwiseAttention(nn.Module):
             gating=False,
         )

-    def forward(self, t, z, chunk_size, template_mask=None):
+    def _chunk(self,
+        z: torch.Tensor,
+        t: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_inputs = {
+            "q_x": z,
+            "k_x": t,
+            "v_x": t,
+            "biases": biases,
+        }
+        return chunk_layer(
+            self.mha,
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(z.shape[:-2]),
+        )
+
+    def forward(self,
+        t: torch.Tensor,
+        z: torch.Tensor,
+        template_mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             t:

@@ -95,21 +121,11 @@ class TemplatePointwiseAttention(nn.Module):
         t = permute_final_dims(t, (1, 2, 0, 3))

         # [*, N_res, N_res, 1, C_z]
-        mha_inputs = {
-            "q_x": z,
-            "k_x": t,
-            "v_x": t,
-            "biases": [bias],
-        }
+        biases = [bias]
         if chunk_size is not None:
-            z = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(z.shape[:-2]),
-            )
+            z = self._chunk(z, t, biases, chunk_size)
         else:
-            z = self.mha(**mha_inputs)
+            z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)

         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)

@@ -120,13 +136,13 @@ class TemplatePairStackBlock(nn.Module):
 class TemplatePairStackBlock(nn.Module):
     def __init__(
         self,
-        c_t,
-        c_hidden_tri_att,
-        c_hidden_tri_mul,
-        no_heads,
-        pair_transition_n,
-        dropout_rate,
-        inf,
+        c_t: int,
+        c_hidden_tri_att: int,
+        c_hidden_tri_mul: int,
+        no_heads: int,
+        pair_transition_n: int,
+        dropout_rate: float,
+        inf: float,
         **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()

@@ -169,7 +185,12 @@ class TemplatePairStackBlock(nn.Module):
             self.pair_transition_n,
         )

-    def forward(self, z, mask, chunk_size, _mask_trans=True):
+    def forward(self,
+        z: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
+    ):
         single_templates = [
             t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
         ]

@@ -208,8 +229,8 @@ class TemplatePairStackBlock(nn.Module):
             )
             single = single + self.pair_transition(
                 single,
-                mask=single_mask if _mask_trans else None,
                 chunk_size=chunk_size,
+                mask=single_mask if _mask_trans else None
             )
             single_templates[i] = single
openfold/model/torchscript.py

-from typing import Optional, Sequence
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Sequence, Tuple

 import torch
 import torch.nn as nn

+from openfold.model.dropout import (
+    DropoutRowwise,
+    DropoutColumnwise,
+)
+from openfold.model.evoformer import (
+    EvoformerBlock,
+    EvoformerStack,
+)
+from openfold.model.outer_product_mean import OuterProductMean
+from openfold.model.msa import (
+    MSARowAttentionWithPairBias,
+    MSAColumnAttention,
+    MSAColumnGlobalAttention,
+)
+from openfold.model.pair_transition import PairTransition
 from openfold.model.primitives import Attention, GlobalAttention
+from openfold.model.structure_module import (
+    InvariantPointAttention,
+    BackboneUpdate,
+)
+from openfold.model.template import TemplatePairStackBlock
+from openfold.model.triangular_attention import (
+    TriangleAttentionStartingNode,
+    TriangleAttentionEndingNode,
+)
+from openfold.model.triangular_multiplicative_update import (
+    TriangleMultiplicationOutgoing,
+    TriangleMultiplicationIncoming,
+)
+
+
+def script_preset_(model: torch.nn.Module):
+    """
+    TorchScript a handful of low-level but frequently used submodule types
+    that are known to be scriptable.
+
+    Args:
+        model:
+            A torch.nn.Module. It should contain at least some modules from
+            this repository, or this function won't do anything.
+    """
+    script_submodules_(
+        model,
+        [
+            nn.Dropout,
+            Attention,
+            GlobalAttention,
+            EvoformerBlock,
+            #TemplatePairStackBlock,
+        ],
+        attempt_trace=False,
+        batch_dims=None,
+    )
+
+
+def _get_module_device(module: torch.nn.Module) -> torch.device:
+    """
+    Fetches the device of a module, assuming that all of the module's
+    parameters reside on a single device
+
+    Args:
+        module: A torch.nn.Module
+    Returns:
+        The module's device
+    """
+    return next(module.parameters()).device
+
+
+def _trace_module(module, batch_dims=None):
+    if(batch_dims is None):
+        batch_dims = ()
+
+    # Stand-in values
+    n_seq = 10
+    n_res = 10
+
+    device = _get_module_device(module)
+
+    def msa(channel_dim):
+        return torch.rand(
+            (*batch_dims, n_seq, n_res, channel_dim),
+            device=device,
+        )
+
+    def pair(channel_dim):
+        return torch.rand(
+            (*batch_dims, n_res, n_res, channel_dim),
+            device=device,
+        )
+
+    if(isinstance(module, MSARowAttentionWithPairBias)):
+        inputs = {
+            "forward": (
+                msa(module.c_in),  # m
+                pair(module.c_z),  # z
+                torch.randint(0, 2, (*batch_dims, n_seq, n_res)),  # mask
+            ),
+        }
+    elif(isinstance(module, MSAColumnAttention)):
+        inputs = {
+            "forward": (
+                msa(module.c_in),  # m
+                torch.randint(0, 2, (*batch_dims, n_seq, n_res)),  # mask
+            ),
+        }
+    elif(isinstance(module, OuterProductMean)):
+        inputs = {
+            "forward": (
+                msa(module.c_m),
+                torch.randint(0, 2, (*batch_dims, n_seq, n_res))
+            )
+        }
+        module = OPM(module)
+    else:
+        raise TypeError(
+            f"tracing is not supported for modules of type {type(module)}"
+        )
+
+    return torch.jit.trace_module(module, inputs)
+
+
+def _script_submodules_helper_(
+    model,
+    types,
+    attempt_trace,
+    to_trace,
+):
+    for name, child in model.named_children():
+        if(types is None or any(isinstance(child, t) for t in types)):
+            try:
+                scripted = torch.jit.script(child)
+                setattr(model, name, scripted)
+                continue
+            except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
+                if(attempt_trace):
+                    to_trace.add(type(child))
+                else:
+                    raise e
+
+        _script_submodules_helper_(child, types, attempt_trace, to_trace)
+
+
+def _trace_submodules_(
+    model,
+    types,
+    batch_dims=None,
+):
+    for name, child in model.named_children():
+        if(any(isinstance(child, t) for t in types)):
+            traced = _trace_module(child, batch_dims=batch_dims)
+            setattr(model, name, traced)
+        else:
+            _trace_submodules_(child, types, batch_dims=batch_dims)
+
+
+def script_primitives_(model):
+    script_submodules_(model, [Attention, GlobalAttention])
+
+
 def script_submodules_(
     model: nn.Module,
     types: Optional[Sequence[type]] = None,
+    attempt_trace: Optional[bool] = True,
+    batch_dims: Optional[Tuple[int]] = None,
 ):
     """
     Convert all submodules whose types match one of those in the input

@@ -21,11 +195,22 @@ def script_submodules_(
     When types is None, all submodules are scripted.

     Args:
-        model: A torch.nn.Module
-        types: A list of types of submodules to script
+        model:
+            A torch.nn.Module
+        types:
+            A list of types of submodules to script
+        attempt_trace:
+            Whether to attempt to trace specified modules if scripting
+            fails. Recall that tracing eliminates all conditional
+            logic---with great tracing comes the mild responsibility of
+            having to remember to ensure that the modules in question
+            perform the same computations no matter what.
     """
-    for name, child in model.named_children():
-        if(types is None or any(isinstance(child, t) for t in types)):
-            setattr(model, name, torch.jit.script(child))
-        else:
-            script_submodules_(child, types)
+    to_trace = set()
+
+    # Aggressively script as much as possible first...
+    _script_submodules_helper_(model, types, attempt_trace, to_trace)
+
+    # ... and then trace stragglers.
+    if(attempt_trace and len(to_trace) > 0):
+        _trace_submodules_(model, to_trace, batch_dims=batch_dims)
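Note: `script_preset_` is the entry point used by run_pretrained_openfold.py and train_openfold.py below; it scripts the submodule types known to be safe (nn.Dropout, Attention, GlobalAttention, EvoformerBlock) in place and leaves the rest of the model in eager mode. A hedged usage sketch; the exact config construction is an assumption, not part of this commit:

    import torch
    from openfold.config import model_config
    from openfold.model.model import AlphaFold
    from openfold.model.torchscript import script_preset_

    config = model_config("model_1")   # assumed config helper call
    model = AlphaFold(config).eval()
    script_preset_(model)              # swaps eligible children for ScriptModules in place

    # Everything else still runs as ordinary Python; only the replaced
    # submodules execute as TorchScript.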
openfold/model/triangular_attention.py

@@ -15,6 +15,8 @@
 from functools import partialmethod
 import math
+from typing import Optional, List

 import torch
 import torch.nn as nn

@@ -55,7 +57,30 @@ class TriangleAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, x, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        x: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_inputs = {
+            "q_x": x,
+            "k_x": x,
+            "v_x": x,
+            "biases": biases,
+        }
+        return chunk_layer(
+            self.mha,
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(x.shape[:-2]),
+        )
+
+    def forward(self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             x:

@@ -86,21 +111,12 @@ class TriangleAttention(nn.Module):
         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)

-        mha_inputs = {
-            "q_x": x,
-            "k_x": x,
-            "v_x": x,
-            "biases": [mask_bias, triangle_bias],
-        }
+        biases = [mask_bias, triangle_bias]
         if chunk_size is not None:
-            x = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(x.shape[:-2]),
-            )
+            x = self._chunk(x, biases, chunk_size)
         else:
-            x = self.mha(**mha_inputs)
+            x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)

         if not self.starting:
             x = x.transpose(-2, -3)
openfold/model/triangular_multiplicative_update.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 from functools import partialmethod
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -25,7 +27,6 @@ class TriangleMultiplicativeUpdate(nn.Module):
     """
     Implements Algorithms 11 and 12.
     """
     def __init__(self, c_z, c_hidden, _outgoing=True):
         """
         Args:

@@ -51,39 +52,16 @@ class TriangleMultiplicativeUpdate(nn.Module):
         self.sigmoid = nn.Sigmoid()

-        cp = self._outgoing_matmul if self._outgoing else self._incoming_matmul
-        self.combine_projections = cp
-
-    def _outgoing_matmul(
-        self,
-        a: torch.Tensor,  # [*, N_i, N_k, C]
-        b: torch.Tensor,  # [*, N_j, N_k, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 0, 1)),
-            permute_final_dims(b, (2, 1, 0)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
-
-    def _incoming_matmul(
-        self,
-        a: torch.Tensor,  # [*, N_k, N_i, C]
-        b: torch.Tensor,  # [*, N_k, N_j, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 1, 0)),
-            permute_final_dims(b, (2, 0, 1)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
-
-    def forward(self, z, mask=None):
+    def _combine_projections(self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError("This method needs to be overridden")
+
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         """
         Args:
             x:

@@ -103,7 +81,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
         a = a * mask
         b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
         b = b * mask

-        x = self.combine_projections(a, b)
+        x = self._combine_projections(a, b)
         x = self.layer_norm_out(x)
         x = self.linear_z(x)
         g = self.sigmoid(self.linear_g(z))

@@ -116,19 +94,36 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 11.
     """
-    __init__ = partialmethod(
-        TriangleMultiplicativeUpdate.__init__,
-        _outgoing=True,
-    )
+    def _combine_projections(self,
+        a: torch.Tensor,  # [*, N_i, N_k, C]
+        b: torch.Tensor,  # [*, N_j, N_k, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 0, 1)),
+            permute_final_dims(b, (2, 1, 0)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))


 class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 12.
     """
-    __init__ = partialmethod(
-        TriangleMultiplicativeUpdate.__init__,
-        _outgoing=False,
-    )
+    def _combine_projections(self,
+        a: torch.Tensor,  # [*, N_k, N_i, C]
+        b: torch.Tensor,  # [*, N_k, N_j, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 1, 0)),
+            permute_final_dims(b, (2, 0, 1)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))
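Note: the old implementation picked the combination function at construction time (`self.combine_projections = self._outgoing_matmul if ... else self._incoming_matmul`) and wired the subclasses up with `partialmethod`; both are patterns the TorchScript compiler cannot follow, which is why the commit moves to a plain method override. A minimal sketch with hypothetical toy classes:

    import torch
    import torch.nn as nn

    class FlagDispatch(nn.Module):
        # Not scriptable: forward calls a bound method stored in an attribute.
        def __init__(self, outgoing: bool = True):
            super().__init__()
            self._fn = self._outgoing if outgoing else self._incoming

        def _outgoing(self, x):
            return x

        def _incoming(self, x):
            return -x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self._fn(x)

    class Base(nn.Module):
        # Scriptable: forward always calls the same method name, and each
        # subclass supplies its own implementation.
        def _combine(self, x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError("override me")

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self._combine(x)

    class Outgoing(Base):
        def _combine(self, x: torch.Tensor) -> torch.Tensor:
            return x

    torch.jit.script(Outgoing())   # compiles; scripting FlagDispatch() fails on _fn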
openfold/utils/affine_utils.py

@@ -631,7 +631,7 @@ _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}

 def _to_mat(pairs):
-    mat = torch.zeros((4, 4))
+    mat = np.zeros((4, 4))
     for pair in pairs:
         key, value = pair
         ind = _qtr_ind_dict[key]
openfold/utils/checkpointing.py

@@ -17,10 +17,11 @@ import torch
 import torch.utils.checkpoint
 from typing import Any, Tuple, List, Callable

 BLOCK_ARG = Any
 BLOCK_ARGS = List[BLOCK_ARG]

+@torch.jit.ignore
 def checkpoint_blocks(
     blocks: List[Callable],
     args: BLOCK_ARGS,
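Note: `checkpoint_blocks` wraps `torch.utils.checkpoint`, which the TorchScript compiler cannot compile, so marking it `@torch.jit.ignore` keeps it callable from partially scripted models while staying in Python. A minimal sketch of the same idea with a hypothetical toy module:

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class BlockStack(nn.Module):
        def __init__(self, c: int, n_blocks: int):
            super().__init__()
            self.blocks = nn.ModuleList([nn.Linear(c, c) for _ in range(n_blocks)])

        @torch.jit.ignore
        def _run_checkpointed(self, x: torch.Tensor) -> torch.Tensor:
            # Activation checkpointing stays behind jit.ignore; a scripted
            # caller just sees an opaque Python call.
            for block in self.blocks:
                x = torch.utils.checkpoint.checkpoint(block, x)
            return x

        def forward(self, x: torch.Tensor, checkpointing: bool = False) -> torch.Tensor:
            if checkpointing:
                return self._run_checkpointed(x)
            for block in self.blocks:
                x = block(x)
            return x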
openfold/utils/import_weights.py

@@ -217,6 +217,11 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         "attention": AttentionGatedParams(matt.mha),
     }
+    MSAColAttParams = lambda matt: {
+        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
+        "attention": AttentionGatedParams(matt._msa_att.mha),
+    }
     MSAGlobalAttParams = lambda matt: {
         "query_norm": LayerNormParams(matt.layer_norm_m),
         "attention": GlobalAttentionParams(matt.global_attention),

@@ -270,7 +275,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
             msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
         else:
             col_att_name = "msa_column_attention"
-            msa_col_att_params = MSAAttParams(b.msa_att_col)
+            msa_col_att_params = MSAColAttParams(b.msa_att_col)

         d = {
             "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
openfold/utils/tensor_utils.py

@@ -107,6 +107,22 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

+def _fetch_dims(tree):
+    shapes = []
+    tree_type = type(tree)
+    if tree_type is dict:
+        for v in tree.values():
+            shapes.extend(_fetch_dims(v))
+    elif tree_type is list or tree_type is tuple:
+        for t in tree:
+            shapes.extend(_fetch_dims(t))
+    elif tree_type is torch.Tensor:
+        shapes.append(tree.shape)
+    else:
+        raise ValueError("Not supported")
+
+    return shapes
+
 def chunk_layer(
     layer: Callable,

@@ -141,33 +157,17 @@ def chunk_layer(
     if not (len(inputs) > 0):
         raise ValueError("Must provide at least one input")

-    def fetch_dims(tree):
-        shapes = []
-        tree_type = type(tree)
-        if tree_type is dict:
-            for v in tree.values():
-                shapes.extend(fetch_dims(v))
-        elif tree_type is list or tree_type is tuple:
-            for t in tree:
-                shapes.extend(fetch_dims(t))
-        elif tree_type is torch.Tensor:
-            shapes.append(tree.shape)
-        else:
-            raise ValueError("Not supported")
-
-        return shapes
-
-    initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
+    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
     orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

-    def prep_inputs(t):
+    def _prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
         if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
             t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
         t = t.reshape(-1, *t.shape[no_batch_dims:])
         return t

-    flattened_inputs = tensor_tree_map(prep_inputs, inputs)
+    flattened_inputs = tensor_tree_map(_prep_inputs, inputs)

     flat_batch_dim = 1
     for d in orig_batch_dims:
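Note: `chunk_layer` is the workhorse behind every `_chunk` helper above: it flattens the leading batch dimensions of all inputs, applies the wrapped layer to `chunk_size`-sized slices, and reassembles an output with the original batch shape; `_fetch_dims` (now module-level so it is shared rather than redefined per call) just collects tensor shapes from the nested input dict/list. A hedged usage sketch; the toy layer is an assumption standing in for, e.g., `MSATransition._transition`:

    import torch
    from openfold.utils.tensor_utils import chunk_layer

    def toy_layer(m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return m * mask.unsqueeze(-1)

    m = torch.randn(4, 128, 256, 32)     # [batch, N_seq, N_res, C]
    mask = torch.ones(4, 128, 256)       # [batch, N_seq, N_res]

    # Run the layer on 16 flattened (batch x N_seq) rows at a time instead of
    # all 512 at once; the result matches toy_layer(m, mask) with lower peak memory.
    out = chunk_layer(
        toy_layer,
        {"m": m, "mask": mask},
        chunk_size=16,
        no_batch_dims=len(m.shape[:-2]),
    )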
run_pretrained_openfold.py

@@ -31,7 +31,7 @@ import torch
 from openfold.config import model_config
 from openfold.data import templates, feature_pipeline, data_pipeline
 from openfold.model.model import AlphaFold
-from openfold.model.torchscript import script_primitives_
+from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants, protein
 import openfold.np.relax.relax as relax
 from openfold.utils.import_weights import (

@@ -49,7 +49,7 @@ def main(args):
     model = AlphaFold(config)
     model = model.eval()
     import_jax_weights_(model, args.param_path)
-    script_primitives_(model)
+    script_preset_(model)
     model = model.to(args.model_device)

     template_featurizer = templates.TemplateHitFeaturizer(
train_openfold.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
-#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
+os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 #os.environ["MASTER_ADDR"]="10.119.81.14"
 #os.environ["MASTER_PORT"]="42069"
 #os.environ["NODE_RANK"]="0"

@@ -23,6 +23,7 @@ from openfold.data.data_modules import (
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
+from openfold.model.torchscript import script_preset_
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )

@@ -64,10 +65,6 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

-        #if(torch.isnan(loss) or torch.isinf(loss)):
-        #    logging.warning("loss is NaN. Skipping example...")
-        #    loss = loss.new_tensor(0., requires_grad=True)
-
         return {"loss": loss}

     def validation_step(self, batch, batch_idx):

@@ -121,6 +118,10 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

+    # TorchScript components of the model
+    script_preset_(model_module)
+
     #data_module = DummyDataLoader("batch.pickle")
     data_module = OpenFoldDataModule(
         config=config.data,