OpenDAS / OpenFold · Commits
"...dynamo-run/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "a03fd3071eb9d42487da00af671115748168f4cf"
Commit 34e9363c, authored Nov 19, 2021 by Gustaf Ahdritz

Make more components TorchScript-able, add tracing

Parent: 34e4e6ce
Changes: showing 15 changed files with 516 additions and 200 deletions (+516, -200)

openfold/model/evoformer.py (+21, -11)
openfold/model/msa.py (+81, -28)
openfold/model/outer_product_mean.py (+35, -18)
openfold/model/pair_transition.py (+22, -9)
openfold/model/structure_module.py (+8, -9)
openfold/model/template.py (+45, -24)
openfold/model/torchscript.py (+200, -15)
openfold/model/triangular_attention.py (+30, -14)
openfold/model/triangular_multiplicative_update.py (+37, -42)
openfold/utils/affine_utils.py (+1, -1)
openfold/utils/checkpointing.py (+2, -1)
openfold/utils/import_weights.py (+6, -1)
openfold/utils/tensor_utils.py (+19, -19)
run_pretrained_openfold.py (+3, -3)
train_openfold.py (+6, -5)
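The entry point this commit adds is script_preset_ in openfold/model/torchscript.py, which scripts a whitelist of known-scriptable submodules in place; run_pretrained_openfold.py and train_openfold.py are switched from the old script_primitives_ to it. A minimal usage sketch, not part of the diff (the "model_1" preset name is an assumption):

    # Sketch only: mirrors the calls added to run_pretrained_openfold.py and
    # train_openfold.py below. The config preset name is an assumption.
    from openfold.config import model_config
    from openfold.model.model import AlphaFold
    from openfold.model.torchscript import script_preset_

    config = model_config("model_1")
    model = AlphaFold(config).eval()

    # Replaces known-scriptable children (nn.Dropout, Attention, GlobalAttention,
    # EvoformerBlock) with torch.jit.script'ed equivalents, in place.
    script_preset_(model)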
openfold/model/evoformer.py

@@ -71,11 +71,25 @@ class MSATransition(nn.Module):
         m = self.linear_2(m) * mask

         return m

+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self._transition,
+            {"m": m, "mask": mask},
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
     def forward(self,
         m: torch.Tensor,
-        mask: torch.Tensor = None,
-        chunk_size: int = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -95,16 +109,10 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)

-        inp = {"m": m, "mask": mask}
         if chunk_size is not None:
-            m = chunk_layer(
-                self._transition,
-                inp,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, mask, chunk_size)
         else:
-            m = self._transition(**inp)
+            m = self._transition(m, mask)

         return m

@@ -201,9 +209,11 @@ class EvoformerBlock(nn.Module):
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
-        chunk_size: int,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         #print(torch.cuda.memory_summary())
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
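This is the pattern repeated across the model files in this commit: the call into chunk_layer (a generic, dict-driven Python helper that is not TorchScript-friendly) is moved into a dedicated _chunk method marked @torch.jit.ignore, so that forward itself compiles and the chunked path simply falls back to the Python interpreter. A toy sketch of the same idea, independent of OpenFold's modules (names and shapes are invented):

    # Toy illustration of the @torch.jit.ignore refactoring; module and shapes
    # are invented.
    from typing import Optional

    import torch
    import torch.nn as nn


    class TinyTransition(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.linear = nn.Linear(c, c)

        @torch.jit.ignore
        def _chunk(self, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
            # Anything the compiler can't handle (dict-based helpers such as
            # chunk_layer, closures, etc.) can live here; scripted callers
            # invoke this method through the Python interpreter.
            pieces = [self.linear(p) for p in x.split(chunk_size, dim=-2)]
            return torch.cat(pieces, dim=-2)

        def forward(self, x: torch.Tensor, chunk_size: Optional[int] = None) -> torch.Tensor:
            if chunk_size is not None:
                return self._chunk(x, chunk_size)
            return self.linear(x)


    scripted = torch.jit.script(TinyTransition(8))  # compiles despite _chunk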
openfold/model/msa.py

@@ -16,7 +16,7 @@
 import math
 import torch
 import torch.nn as nn
-from typing import Optional
+from typing import Optional, List

 from openfold.model.primitives import Linear, Attention, GlobalAttention
 from openfold.utils.tensor_utils import (

@@ -63,6 +63,8 @@ class MSAAttention(nn.Module):
         self.layer_norm_m = nn.LayerNorm(self.c_in)

+        self.layer_norm_z = None
+        self.linear_z = None
         if self.pair_bias:
             self.layer_norm_z = nn.LayerNorm(self.c_z)
             self.linear_z = Linear(

@@ -73,7 +75,25 @@ class MSAAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, m, chunk_size, z=None, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self.mha,
+            {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+    ) -> torch.Tensor:
         """
         Args:
             m:

@@ -83,6 +103,11 @@ class MSAAttention(nn.Module):
                 pair_bias is True
             mask:
                 [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
         """
         # [*, N_seq, N_res, C_m]
         m = self.layer_norm_m(m)

@@ -106,7 +131,11 @@ class MSAAttention(nn.Module):
         biases = [bias]
-        if self.pair_bias:
+        if (self.pair_bias and
+            z is not None and                  # For the
+            self.layer_norm_z is not None and  # benefit of
+            self.linear_z is not None          # TorchScript
+        ):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)

@@ -118,16 +147,10 @@ class MSAAttention(nn.Module):
             biases.append(z)

-        mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
         if chunk_size is not None:
-            m = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, biases, chunk_size)
         else:
-            m = self.mha(**mha_inputs)
+            m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)

         return m

@@ -161,9 +184,12 @@ class MSARowAttentionWithPairBias(MSAAttention):
         )


-class MSAColumnAttention(MSAAttention):
+class MSAColumnAttention(nn.Module):
     """
     Implements Algorithm 8.
+
+    By rights, this should also be a subclass of MSAAttention. Alas,
+    most inheritance isn't supported by TorchScript.
     """
     def __init__(self, c_m, c_hidden, no_heads, inf=1e9):

@@ -178,7 +204,14 @@ class MSAColumnAttention(MSAAttention):
             inf:
                 Large number used to construct attention masks
         """
-        super(MSAColumnAttention, self).__init__(
+        super(MSAColumnAttention, self).__init__()
+
+        self.c_m = c_m
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+
+        self._msa_att = MSAAttention(
             c_in=c_m,
             c_hidden=c_hidden,
             no_heads=no_heads,

@@ -187,31 +220,40 @@ class MSAColumnAttention(MSAAttention):
             inf=inf,
         )

-    def forward(self, m, chunk_size, mask=None):
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding
             mask:
                 [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = super().forward(m, chunk_size=chunk_size, mask=mask)
+        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
         if mask is not None:
             mask = mask.transpose(-1, -2)

         return m


 class MSAColumnGlobalAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10
+        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
     ):
         super(MSAColumnGlobalAttention, self).__init__()

@@ -231,8 +273,28 @@ class MSAColumnGlobalAttention(nn.Module):
             eps=eps,
         )

+    @torch.jit.ignore
+    def _chunk(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_input = {
+            "m": m,
+            "mask": mask,
+        }
+        return chunk_layer(
+            self.global_attention,
+            mha_input,
+            chunk_size=chunk_size,
+            no_batch_dims=len(m.shape[:-2]),
+        )
+
     def forward(
-        self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None
+        self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

@@ -251,19 +313,10 @@ class MSAColumnGlobalAttention(nn.Module):
         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)

-        mha_input = {
-            "m": m,
-            "mask": mask,
-        }
         if chunk_size is not None:
-            m = chunk_layer(
-                self.global_attention,
-                mha_input,
-                chunk_size=chunk_size,
-                no_batch_dims=len(m.shape[:-2]),
-            )
+            m = self._chunk(m, mask, chunk_size)
         else:
-            m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
+            m = self.global_attention(m=m, mask=mask)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
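As the new docstring notes, MSAColumnAttention would "by rights" remain a subclass of MSAAttention, but TorchScript supports very little inheritance (super().forward() calls generally don't compile), so the class now owns an MSAAttention instance and delegates to it. A generic sketch of that composition-over-inheritance move, with invented module names:

    # Invented example of the composition-over-inheritance refactor applied to
    # MSAColumnAttention in this commit.
    import torch
    import torch.nn as nn


    class RowAttention(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.proj = nn.Linear(c, c)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    class ColumnAttention(nn.Module):
        """Would "by rights" be a subclass of RowAttention; instead it owns one."""
        def __init__(self, c: int):
            super().__init__()
            self._att = RowAttention(c)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x.transpose(-2, -3)   # operate over the other axis
            x = self._att(x)          # delegate instead of super().forward()
            return x.transpose(-2, -3)


    scripted = torch.jit.script(ColumnAttention(8))  # scripts cleanly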
openfold/model/outer_product_mean.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 from functools import partial
+from typing import Optional
+
 import torch
 import torch.nn as nn

@@ -38,6 +40,7 @@ class OuterProductMean(nn.Module):
         """
         super(OuterProductMean, self).__init__()

+        self.c_m = c_m
         self.c_z = c_z
         self.c_hidden = c_hidden
         self.eps = eps

@@ -52,14 +55,43 @@ class OuterProductMean(nn.Module):
         outer = torch.einsum("...bac,...dae->...bdce", a, b)

         # [*, N_res, N_res, C * C]
-        outer = outer.reshape(*outer.shape[:-2], -1)
+        outer = outer.reshape(outer.shape[:-2] + (-1,))

         # [*, N_res, N_res, C_z]
         outer = self.linear_out(outer)

         return outer

-    def forward(self, m, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        chunk_size: int
+    ) -> torch.Tensor:
+        # Since the "batch dim" in this case is not a true batch dimension
+        # (in that the shape of the output depends on it), we need to
+        # iterate over it ourselves
+        a_reshape = a.reshape((-1,) + a.shape[-3:])
+        b_reshape = b.reshape((-1,) + b.shape[-3:])
+        out = []
+        for a_prime, b_prime in zip(a_reshape, b_reshape):
+            outer = chunk_layer(
+                partial(self._opm, b=b_prime),
+                {"a": a_prime},
+                chunk_size=chunk_size,
+                no_batch_dims=1,
+            )
+            out.append(outer)
+        outer = torch.stack(out, dim=0)
+        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
+
+        return outer
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             m:

@@ -84,22 +116,7 @@ class OuterProductMean(nn.Module):
         b = b.transpose(-2, -3)

         if chunk_size is not None:
-            # Since the "batch dim" in this case is not a true batch dimension
-            # (in that the shape of the output depends on it), we need to
-            # iterate over it ourselves
-            a_reshape = a.reshape(-1, *a.shape[-3:])
-            b_reshape = b.reshape(-1, *b.shape[-3:])
-            out = []
-            for a_prime, b_prime in zip(a_reshape, b_reshape):
-                outer = chunk_layer(
-                    partial(self._opm, b=b_prime),
-                    {"a": a_prime},
-                    chunk_size=chunk_size,
-                    no_batch_dims=1,
-                )
-                out.append(outer)
-            outer = torch.stack(out, dim=0)
-            outer = outer.reshape(*a.shape[:-3], *outer.shape[1:])
+            outer = self._chunk(a, b, chunk_size)
         else:
             outer = self._opm(a, b)
openfold/model/pair_transition.py

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -54,7 +55,25 @@ class PairTransition(nn.Module):
         return z

-    def forward(self, z, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        z: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+    ) -> torch.Tensor:
+        return chunk_layer(
+            self._transition,
+            {"z": z, "mask": mask},
+            chunk_size=chunk_size,
+            no_batch_dims=len(z.shape[:-2]),
+        )
+
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+    ) -> torch.Tensor:
         """
         Args:
             z:

@@ -72,15 +91,9 @@ class PairTransition(nn.Module):
         # [*, N_res, N_res, C_z]
         z = self.layer_norm(z)

-        inp = {"z": z, "mask": mask}
         if chunk_size is not None:
-            z = chunk_layer(
-                self._transition,
-                inp,
-                chunk_size=chunk_size,
-                no_batch_dims=len(z.shape[:-2]),
-            )
+            z = self._chunk(z, mask, chunk_size)
         else:
-            z = self._transition(**inp)
+            z = self._transition(z=z, mask=mask)

         return z
openfold/model/structure_module.py

@@ -155,17 +155,16 @@ class InvariantPointAttention(nn.Module):
     """
     Implements Algorithm 22.
     """
     def __init__(
         self,
-        c_s,
-        c_z,
-        c_hidden,
-        no_heads,
-        no_qk_points,
-        no_v_points,
-        inf=1e5,
-        eps=1e-8,
+        c_s: int,
+        c_z: int,
+        c_hidden: int,
+        no_heads: int,
+        no_qk_points: int,
+        no_v_points: int,
+        inf: float = 1e5,
+        eps: float = 1e-8,
     ):
         """
         Args:
openfold/model/template.py

@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import math
+from typing import Optional, List

 import torch
 import torch.nn as nn

@@ -71,7 +72,32 @@ class TemplatePointwiseAttention(nn.Module):
             gating=False,
         )

-    def forward(self, t, z, chunk_size, template_mask=None):
+    def _chunk(self,
+        z: torch.Tensor,
+        t: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_inputs = {
+            "q_x": z,
+            "k_x": t,
+            "v_x": t,
+            "biases": biases,
+        }
+        return chunk_layer(
+            self.mha,
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(z.shape[:-2]),
+        )
+
+    def forward(self,
+        t: torch.Tensor,
+        z: torch.Tensor,
+        template_mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             t:

@@ -95,21 +121,11 @@ class TemplatePointwiseAttention(nn.Module):
         t = permute_final_dims(t, (1, 2, 0, 3))

         # [*, N_res, N_res, 1, C_z]
-        mha_inputs = {
-            "q_x": z,
-            "k_x": t,
-            "v_x": t,
-            "biases": [bias],
-        }
+        biases = [bias]
         if chunk_size is not None:
-            z = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(z.shape[:-2]),
-            )
+            z = self._chunk(z, t, biases, chunk_size)
         else:
-            z = self.mha(**mha_inputs)
+            z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)

         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)

@@ -120,13 +136,13 @@ class TemplatePointwiseAttention(nn.Module):
 class TemplatePairStackBlock(nn.Module):
     def __init__(
         self,
-        c_t,
-        c_hidden_tri_att,
-        c_hidden_tri_mul,
-        no_heads,
-        pair_transition_n,
-        dropout_rate,
-        inf,
+        c_t: int,
+        c_hidden_tri_att: int,
+        c_hidden_tri_mul: int,
+        no_heads: int,
+        pair_transition_n: int,
+        dropout_rate: float,
+        inf: float,
         **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()

@@ -169,7 +185,12 @@ class TemplatePairStackBlock(nn.Module):
             self.pair_transition_n,
         )

-    def forward(self, z, mask, chunk_size, _mask_trans=True):
+    def forward(self,
+        z: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
+    ):
         single_templates = [
             t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
         ]

@@ -208,8 +229,8 @@ class TemplatePairStackBlock(nn.Module):
             )
             single = single + self.pair_transition(
                 single,
-                mask=single_mask if _mask_trans else None,
                 chunk_size=chunk_size,
+                mask=single_mask if _mask_trans else None,
             )
             single_templates[i] = single
openfold/model/torchscript.py (shown in full; lines prefixed with "-" are removed by this commit)

-from typing import Optional, Sequence
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from openfold.model.dropout import (
    DropoutRowwise,
    DropoutColumnwise,
)
from openfold.model.evoformer import (
    EvoformerBlock,
    EvoformerStack,
)
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from openfold.model.pair_transition import PairTransition
from openfold.model.primitives import Attention, GlobalAttention
from openfold.model.structure_module import (
    InvariantPointAttention,
    BackboneUpdate,
)
from openfold.model.template import TemplatePairStackBlock
from openfold.model.triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)


def script_preset_(model: torch.nn.Module):
    """
    TorchScript a handful of low-level but frequently used submodule types
    that are known to be scriptable.

    Args:
        model:
            A torch.nn.Module. It should contain at least some modules from
            this repository, or this function won't do anything.
    """
    script_submodules_(
        model,
        [
            nn.Dropout,
            Attention,
            GlobalAttention,
            EvoformerBlock,
            #TemplatePairStackBlock,
        ],
        attempt_trace=False,
        batch_dims=None,
    )


def _get_module_device(module: torch.nn.Module) -> torch.device:
    """
    Fetches the device of a module, assuming that all of the module's
    parameters reside on a single device

    Args:
        module: A torch.nn.Module
    Returns:
        The module's device
    """
    return next(module.parameters()).device


def _trace_module(module, batch_dims=None):
    if(batch_dims is None):
        batch_dims = ()

    # Stand-in values
    n_seq = 10
    n_res = 10

    device = _get_module_device(module)

    def msa(channel_dim):
        return torch.rand(
            (*batch_dims, n_seq, n_res, channel_dim),
            device=device,
        )

    def pair(channel_dim):
        return torch.rand(
            (*batch_dims, n_res, n_res, channel_dim),
            device=device,
        )

    if(isinstance(module, MSARowAttentionWithPairBias)):
        inputs = {
            "forward": (
                msa(module.c_in),  # m
                pair(module.c_z),  # z
                torch.randint(0, 2, (*batch_dims, n_seq, n_res)),  # mask
            ),
        }
    elif(isinstance(module, MSAColumnAttention)):
        inputs = {
            "forward": (
                msa(module.c_in),  # m
                torch.randint(0, 2, (*batch_dims, n_seq, n_res)),  # mask
            ),
        }
    elif(isinstance(module, OuterProductMean)):
        inputs = {
            "forward": (
                msa(module.c_m),
                torch.randint(0, 2, (*batch_dims, n_seq, n_res))
            )
        }
        module = OPM(module)
    else:
        raise TypeError(
            f"tracing is not supported for modules of type {type(module)}"
        )

    return torch.jit.trace_module(module, inputs)


def _script_submodules_helper_(
    model,
    types,
    attempt_trace,
    to_trace,
):
    for name, child in model.named_children():
        if(types is None or any(isinstance(child, t) for t in types)):
            try:
                scripted = torch.jit.script(child)
                setattr(model, name, scripted)
                continue
            except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
                if(attempt_trace):
                    to_trace.add(type(child))
                else:
                    raise e

        _script_submodules_helper_(child, types, attempt_trace, to_trace)


def _trace_submodules_(
    model,
    types,
    batch_dims=None,
):
    for name, child in model.named_children():
        if(any(isinstance(child, t) for t in types)):
            traced = _trace_module(child, batch_dims=batch_dims)
            setattr(model, name, traced)
        else:
            _trace_submodules_(child, types, batch_dims=batch_dims)


-def script_primitives_(model):
-    script_submodules_(model, [Attention, GlobalAttention])


def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
    attempt_trace: Optional[bool] = True,
    batch_dims: Optional[Tuple[int]] = None,
):
    """
    Convert all submodules whose types match one of those in the input
    list to recursively scripted equivalents in place. To script the entire
    model, just call torch.jit.script on it directly.

    When types is None, all submodules are scripted.

    Args:
        model:
            A torch.nn.Module
        types:
            A list of types of submodules to script
        attempt_trace:
            Whether to attempt to trace specified modules if scripting
            fails. Recall that tracing eliminates all conditional
            logic---with great tracing comes the mild responsibility of
            having to remember to ensure that the modules in question
            perform the same computations no matter what.
    """
-    for name, child in model.named_children():
-        if(types is None or any(isinstance(child, t) for t in types)):
-            setattr(model, name, torch.jit.script(child))
-        else:
-            script_submodules_(child, types)
    to_trace = set()

    # Aggressively script as much as possible first...
    _script_submodules_helper_(model, types, attempt_trace, to_trace)

    # ... and then trace stragglers.
    if(attempt_trace and len(to_trace) > 0):
        _trace_submodules_(model, to_trace, batch_dims=batch_dims)
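script_submodules_ now scripts matching children first and, when attempt_trace is set, collects the types that fail and falls back to torch.jit.trace_module with stand-in inputs (per _trace_module, only MSARowAttentionWithPairBias, MSAColumnAttention, and OuterProductMean are traceable here). A hedged sketch of calling it directly; the model construction and the batch_dims value are illustrative, not taken from the diff:

    # Illustrative only; "model_1" and batch_dims=(1,) are assumptions.
    from openfold.config import model_config
    from openfold.model.model import AlphaFold
    from openfold.model.msa import MSARowAttentionWithPairBias
    from openfold.model.primitives import Attention, GlobalAttention
    from openfold.model.torchscript import script_submodules_

    model = AlphaFold(model_config("model_1")).eval()

    # Script just the attention primitives; any failure is raised immediately.
    script_submodules_(model, [Attention, GlobalAttention], attempt_trace=False)

    # Let submodules that refuse to script be traced with dummy inputs of batch
    # shape (1,) instead. Tracing freezes conditional logic, so this is only
    # safe for modules whose control flow does not depend on their inputs.
    script_submodules_(
        model,
        [MSARowAttentionWithPairBias],
        attempt_trace=True,
        batch_dims=(1,),
    )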
openfold/model/triangular_attention.py

@@ -15,6 +15,8 @@
 from functools import partialmethod
 import math
+from typing import Optional, List
+
 import torch
 import torch.nn as nn

@@ -55,7 +57,30 @@ class TriangleAttention(nn.Module):
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )

-    def forward(self, x, chunk_size, mask=None):
+    @torch.jit.ignore
+    def _chunk(self,
+        x: torch.Tensor,
+        biases: List[torch.Tensor],
+        chunk_size: int,
+    ) -> torch.Tensor:
+        mha_inputs = {
+            "q_x": x,
+            "k_x": x,
+            "v_x": x,
+            "biases": biases,
+        }
+        return chunk_layer(
+            self.mha,
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(x.shape[:-2]),
+        )
+
+    def forward(self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Args:
             x:

@@ -86,21 +111,12 @@ class TriangleAttention(nn.Module):
         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)

-        mha_inputs = {
-            "q_x": x,
-            "k_x": x,
-            "v_x": x,
-            "biases": [mask_bias, triangle_bias],
-        }
+        biases = [mask_bias, triangle_bias]
         if chunk_size is not None:
-            x = chunk_layer(
-                self.mha,
-                mha_inputs,
-                chunk_size=chunk_size,
-                no_batch_dims=len(x.shape[:-2]),
-            )
+            x = self._chunk(x, biases, chunk_size)
         else:
-            x = self.mha(**mha_inputs)
+            x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)

         if not self.starting:
             x = x.transpose(-2, -3)
openfold/model/triangular_multiplicative_update.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 from functools import partialmethod
+from typing import Optional
+
 import torch
 import torch.nn as nn

@@ -25,7 +27,6 @@ class TriangleMultiplicativeUpdate(nn.Module):
     """
     Implements Algorithms 11 and 12.
     """
-
     def __init__(self, c_z, c_hidden, _outgoing=True):
         """
         Args:

@@ -51,39 +52,16 @@ class TriangleMultiplicativeUpdate(nn.Module):
         self.sigmoid = nn.Sigmoid()

-        cp = self._outgoing_matmul if self._outgoing else self._incoming_matmul
-        self.combine_projections = cp
-
-    def _outgoing_matmul(self,
-        a: torch.Tensor,  # [*, N_i, N_k, C]
-        b: torch.Tensor,  # [*, N_j, N_k, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 0, 1)),
-            permute_final_dims(b, (2, 1, 0)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
-
-    def _incoming_matmul(self,
-        a: torch.Tensor,  # [*, N_k, N_i, C]
-        b: torch.Tensor,  # [*, N_k, N_j, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 1, 0)),
-            permute_final_dims(b, (2, 0, 1)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
+    def _combine_projections(self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError("This method needs to be overridden")

-    def forward(self, z, mask=None):
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         """
         Args:
             x:

@@ -103,7 +81,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
         a = a * mask
         b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
         b = b * mask
-        x = self.combine_projections(a, b)
+        x = self._combine_projections(a, b)
         x = self.layer_norm_out(x)
         x = self.linear_z(x)
         g = self.sigmoid(self.linear_g(z))

@@ -116,19 +94,36 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 11.
     """
-    __init__ = partialmethod(
-        TriangleMultiplicativeUpdate.__init__,
-        _outgoing=True,
-    )
+    def _combine_projections(self,
+        a: torch.Tensor,  # [*, N_i, N_k, C]
+        b: torch.Tensor,  # [*, N_j, N_k, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 0, 1)),
+            permute_final_dims(b, (2, 1, 0)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))


 class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 12.
     """
-    __init__ = partialmethod(
-        TriangleMultiplicativeUpdate.__init__,
-        _outgoing=False,
-    )
+    def _combine_projections(self,
+        a: torch.Tensor,  # [*, N_k, N_i, C]
+        b: torch.Tensor,  # [*, N_k, N_j, C]
+    ):
+        # [*, C, N_i, N_j]
+        p = torch.matmul(
+            permute_final_dims(a, (2, 1, 0)),
+            permute_final_dims(b, (2, 0, 1)),
+        )
+
+        # [*, N_i, N_j, C]
+        return permute_final_dims(p, (1, 2, 0))
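The base class previously picked between _outgoing_matmul and _incoming_matmul in __init__ and stored the chosen bound method as self.combine_projections; TorchScript cannot script a module whose behavior is selected through a method-valued attribute, so the direction-specific logic now lives in a _combine_projections method that each subclass overrides. A toy before/after sketch with invented names:

    # Invented example mirroring the TriangleMultiplicativeUpdate refactor.
    import torch
    import torch.nn as nn


    class BadUpdate(nn.Module):
        """Not scriptable: behavior picked by stashing a bound method."""
        def __init__(self, outgoing: bool = True):
            super().__init__()
            self.combine = self._out if outgoing else self._in  # breaks scripting

        def _out(self, a, b):
            return a @ b.transpose(-1, -2)

        def _in(self, a, b):
            return a.transpose(-1, -2) @ b

        def forward(self, a, b):
            return self.combine(a, b)


    class GoodUpdate(nn.Module):
        """Scriptable base: subclasses override _combine_projections."""
        def _combine_projections(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError("This method needs to be overridden")

        def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return self._combine_projections(a, b)


    class OutgoingUpdate(GoodUpdate):
        def _combine_projections(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a @ b.transpose(-1, -2)


    scripted = torch.jit.script(OutgoingUpdate())  # compiles; BadUpdate would not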
openfold/utils/affine_utils.py

@@ -631,7 +631,7 @@ _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
 def _to_mat(pairs):
-    mat = torch.zeros((4, 4))
+    mat = np.zeros((4, 4))
     for pair in pairs:
         key, value = pair
         ind = _qtr_ind_dict[key]
openfold/utils/checkpointing.py

@@ -17,10 +17,11 @@ import torch
 import torch.utils.checkpoint

 from typing import Any, Tuple, List, Callable

 BLOCK_ARG = Any
 BLOCK_ARGS = List[BLOCK_ARG]


+@torch.jit.ignore
 def checkpoint_blocks(
     blocks: List[Callable],
     args: BLOCK_ARGS,
openfold/utils/import_weights.py

@@ -217,6 +217,11 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         "attention": AttentionGatedParams(matt.mha),
     }

+    MSAColAttParams = lambda matt: {
+        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
+        "attention": AttentionGatedParams(matt._msa_att.mha),
+    }
+
     MSAGlobalAttParams = lambda matt: {
         "query_norm": LayerNormParams(matt.layer_norm_m),
         "attention": GlobalAttentionParams(matt.global_attention),

@@ -270,7 +275,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
             msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
         else:
             col_att_name = "msa_column_attention"
-            msa_col_att_params = MSAAttParams(b.msa_att_col)
+            msa_col_att_params = MSAColAttParams(b.msa_att_col)

     d = {
         "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
openfold/utils/tensor_utils.py

@@ -107,6 +107,22 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


+def _fetch_dims(tree):
+    shapes = []
+    tree_type = type(tree)
+    if tree_type is dict:
+        for v in tree.values():
+            shapes.extend(_fetch_dims(v))
+    elif tree_type is list or tree_type is tuple:
+        for t in tree:
+            shapes.extend(_fetch_dims(t))
+    elif tree_type is torch.Tensor:
+        shapes.append(tree.shape)
+    else:
+        raise ValueError("Not supported")
+
+    return shapes
+
+
 def chunk_layer(
     layer: Callable,

@@ -141,33 +157,17 @@ def chunk_layer(
     if not (len(inputs) > 0):
         raise ValueError("Must provide at least one input")

-    def fetch_dims(tree):
-        shapes = []
-        tree_type = type(tree)
-        if tree_type is dict:
-            for v in tree.values():
-                shapes.extend(fetch_dims(v))
-        elif tree_type is list or tree_type is tuple:
-            for t in tree:
-                shapes.extend(fetch_dims(t))
-        elif tree_type is torch.Tensor:
-            shapes.append(tree.shape)
-        else:
-            raise ValueError("Not supported")
-
-        return shapes
-
-    initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
+    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
     orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

-    def prep_inputs(t):
+    def _prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
         if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
             t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
         t = t.reshape(-1, *t.shape[no_batch_dims:])
         return t

-    flattened_inputs = tensor_tree_map(prep_inputs, inputs)
+    flattened_inputs = tensor_tree_map(_prep_inputs, inputs)

     flat_batch_dim = 1
     for d in orig_batch_dims:
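chunk_layer, whose inner helpers are hoisted to module level here as _fetch_dims and _prep_inputs, applies a callable to chunk_size-sized slices of its inputs' flattened batch dimensions and stitches the results back together; it is what every _chunk method above delegates to. A small usage sketch with made-up shapes:

    # Made-up shapes; shows the chunk_layer call signature used by the
    # _chunk methods added in this commit.
    import torch
    import torch.nn as nn

    from openfold.utils.tensor_utils import chunk_layer

    linear = nn.Linear(32, 32)
    m = torch.rand(4, 128, 64, 32)  # e.g. [batch, N_seq, N_res, C]

    # Apply `linear` over the leading (4, 128) "batch" dims, 16 rows of the
    # flattened batch at a time, instead of materializing everything at once.
    # The callable is invoked with keyword arguments matching the dict keys.
    out = chunk_layer(
        lambda m: linear(m),
        {"m": m},
        chunk_size=16,
        no_batch_dims=len(m.shape[:-2]),
    )
    assert out.shape == m.shape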
run_pretrained_openfold.py

@@ -31,7 +31,7 @@ import torch
 from openfold.config import model_config
 from openfold.data import templates, feature_pipeline, data_pipeline
 from openfold.model.model import AlphaFold
-from openfold.model.torchscript import script_primitives_
+from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants, protein
 import openfold.np.relax.relax as relax
 from openfold.utils.import_weights import (

@@ -49,9 +49,9 @@ def main(args):
     model = AlphaFold(config)
     model = model.eval()
     import_jax_weights_(model, args.param_path)
-    script_primitives_(model)
+    script_preset_(model)
     model = model.to(args.model_device)

     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
         max_template_date=args.max_template_date,
train_openfold.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os

-#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
+os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 #os.environ["MASTER_ADDR"]="10.119.81.14"
 #os.environ["MASTER_PORT"]="42069"
 #os.environ["NODE_RANK"]="0"

@@ -23,6 +23,7 @@ from openfold.data.data_modules import (
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
+from openfold.model.torchscript import script_preset_
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )

@@ -64,10 +65,6 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

-        #if(torch.isnan(loss) or torch.isinf(loss)):
-        #    logging.warning("loss is NaN. Skipping example...")
-        #    loss = loss.new_tensor(0., requires_grad=True)
-
         return {"loss": loss}

     def validation_step(self, batch, batch_idx):

@@ -121,6 +118,10 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

+    # TorchScript components of the model
+    script_preset_(model_module)
+
     #data_module = DummyDataLoader("batch.pickle")
     data_module = OpenFoldDataModule(
         config=config.data,