Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
b2d102cb
Commit
b2d102cb
authored
Sep 27, 2021
by
Gustaf Ahdritz
Browse files
Refactor certain modules for TorchScript, fix recycling bug
parent
4bd4ad93
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
349 additions
and
318 deletions
+349
-318
openfold/model/embedders.py
openfold/model/embedders.py
+8
-6
openfold/model/model.py
openfold/model/model.py
+155
-148
openfold/model/msa.py
openfold/model/msa.py
+24
-72
openfold/model/primitives.py
openfold/model/primitives.py
+79
-16
openfold/model/structure_module.py
openfold/model/structure_module.py
+32
-27
openfold/model/template.py
openfold/model/template.py
+4
-4
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+3
-3
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+6
-6
openfold/utils/affine_utils.py
openfold/utils/affine_utils.py
+17
-22
openfold/utils/deepspeed.py
openfold/utils/deepspeed.py
+3
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+1
-1
openfold/utils/loss.py
openfold/utils/loss.py
+1
-1
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+16
-10
No files found.
openfold/model/embedders.py
View file @
b2d102cb
...
@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
...
@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
boundaries
=
torch
.
arange
(
boundaries
=
torch
.
arange
(
start
=-
self
.
relpos_k
,
end
=
self
.
relpos_k
+
1
,
device
=
d
.
device
start
=-
self
.
relpos_k
,
end
=
self
.
relpos_k
+
1
,
device
=
d
.
device
)
)
oh
=
one_hot
(
d
,
boundaries
)
oh
=
one_hot
(
d
,
boundaries
)
.
type
(
ri
.
dtype
)
return
self
.
linear_relpos
(
oh
)
return
self
.
linear_relpos
(
oh
)
def
forward
(
self
,
def
forward
(
self
,
...
@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
...
@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
# [*, N_res, N_res, c_z]
# [*, N_res, N_res, c_z]
pair_emb
=
tf_emb_i
[...,
None
,
:]
+
tf_emb_j
[...,
None
,
:,
:]
pair_emb
=
tf_emb_i
[...,
None
,
:]
+
tf_emb_j
[...,
None
,
:,
:]
pair_emb
+=
self
.
relpos
(
ri
)
pair_emb
=
pair_emb
+
self
.
relpos
(
ri
.
type
(
pair_emb
.
dtype
))
#pair_emb = pair_emb + self.relpos(ri)
# [*, N_clust, N_res, c_m]
# [*, N_clust, N_res, c_m]
n_clust
=
msa
.
shape
[
-
3
]
n_clust
=
msa
.
shape
[
-
3
]
tf_m
=
(
self
.
linear_tf_m
(
tf
)
tf_m
=
(
self
.
linear_tf_m
(
tf
)
.
unsqueeze
(
-
3
)
.
unsqueeze
(
-
3
)
.
expand
((
*
(
-
1
,)
*
len
(
tf
.
shape
[:
-
2
]),
n_clust
,
-
1
,
-
1
)))
.
expand
(((
-
1
,)
*
len
(
tf
.
shape
[:
-
2
])
+
(
n_clust
,
-
1
,
-
1
)))
)
msa_emb
=
self
.
linear_msa_m
(
msa
)
+
tf_m
msa_emb
=
self
.
linear_msa_m
(
msa
)
+
tf_m
return
msa_emb
,
pair_emb
return
msa_emb
,
pair_emb
...
@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
self
.
min_bin
,
self
.
min_bin
,
self
.
max_bin
,
self
.
max_bin
,
self
.
no_bins
,
self
.
no_bins
,
dtype
=
x
.
dtype
,
requires_grad
=
False
,
requires_grad
=
False
,
device
=
x
.
device
device
=
x
.
device
)
)
...
...
openfold/model/model.py
View file @
b2d102cb
...
@@ -170,68 +170,10 @@ class AlphaFold(nn.Module):
...
@@ -170,68 +170,10 @@ class AlphaFold(nn.Module):
"torsion_angles_mask"
:
angle_feats
[
"torsion_angles_mask"
],
"torsion_angles_mask"
:
angle_feats
[
"torsion_angles_mask"
],
}
}
def
forward
(
self
,
batch
):
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine,
for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
# Primary output dictionary
# Primary output dictionary
outputs
=
{}
outputs
=
{}
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
no_cycles
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Grab some data about the input
# Grab some data about the input
batch_dims
=
feats
[
"target_feat"
].
shape
[:
-
2
]
batch_dims
=
feats
[
"target_feat"
].
shape
[:
-
2
]
no_batch_dims
=
len
(
batch_dims
)
no_batch_dims
=
len
(
batch_dims
)
...
@@ -259,24 +201,21 @@ class AlphaFold(nn.Module):
...
@@ -259,24 +201,21 @@ class AlphaFold(nn.Module):
# Initialize the recycling embeddings, if needs be
# Initialize the recycling embeddings, if needs be
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
# [*, N, C_m]
# [*, N, C_m]
m_1_prev
=
torch
.
zeros
(
m_1_prev
=
m
.
new_
zeros
(
(
*
batch_dims
,
n
,
self
.
config
.
c_m
),
(
*
batch_dims
,
n
,
self
.
config
.
c_m
),
requires_grad
=
False
,
requires_grad
=
False
,
device
=
device
,
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
torch
.
zeros
(
z_prev
=
z
.
new_
zeros
(
(
*
batch_dims
,
n
,
n
,
self
.
config
.
c_z
),
(
*
batch_dims
,
n
,
n
,
self
.
config
.
c_z
),
requires_grad
=
False
,
requires_grad
=
False
,
device
=
device
,
)
)
# [*, N, 3]
# [*, N, 3]
x_prev
=
torch
.
zeros
(
x_prev
=
z
.
new_
zeros
(
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
requires_grad
=
False
,
requires_grad
=
False
,
device
=
device
,
)
)
x_prev
=
pseudo_beta_fn
(
x_prev
=
pseudo_beta_fn
(
...
@@ -377,6 +316,74 @@ class AlphaFold(nn.Module):
...
@@ -377,6 +316,74 @@ class AlphaFold(nn.Module):
# [*, N, 3]
# [*, N, 3]
x_prev
=
outputs
[
"final_atom_positions"
]
x_prev
=
outputs
[
"final_atom_positions"
]
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
def
forward
(
self
,
batch
):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine,
for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
no_cycles
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
(
cycle_no
==
self
.
config
.
no_cycles
-
1
)
with
torch
.
set_grad_enabled
(
self
.
training
and
is_final_iter
):
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
m_1_prev
,
z_prev
,
x_prev
,
)
outputs
.
update
(
self
.
aux_heads
(
outputs
))
outputs
.
update
(
self
.
aux_heads
(
outputs
))
return
outputs
return
outputs
openfold/model/msa.py
View file @
b2d102cb
...
@@ -16,8 +16,9 @@
...
@@ -16,8 +16,9 @@
import
math
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Optional
from
openfold.model.primitives
import
Linear
,
scripted_a
ttention
from
openfold.model.primitives
import
Linear
,
Attention
,
GlobalA
ttention
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
...
@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
...
@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
)
self
.
mha
=
scripted_attention
(
self
.
mha
=
Attention
(
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
c_hidden
,
self
.
no_heads
self
.
no_heads
...
@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
...
@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
if
(
mask
is
None
):
if
(
mask
is
None
):
# [*, N_seq, N_res]
# [*, N_seq, N_res]
mask
=
torch
.
ones
(
mask
=
torch
.
ones
(
(
*
m
.
shape
[:
-
3
]
,
n_seq
,
n_res
),
m
.
shape
[:
-
3
]
+
(
n_seq
,
n_res
),
device
=
m
.
device
,
device
=
m
.
device
,
requires_grad
=
False
requires_grad
=
False
)
)
...
@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
...
@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
# [*, N_seq, no_heads, N_res, N_res]
# [*, N_seq, no_heads, N_res, N_res]
bias
=
bias
.
expand
(
bias
=
bias
.
expand
(
(
*
((
-
1
,)
*
len
(
bias
.
shape
[:
-
4
]))
,
-
1
,
self
.
no_heads
,
n_res
,
-
1
)
((
-
1
,)
*
len
(
bias
.
shape
[:
-
4
]))
+
(
-
1
,
self
.
no_heads
,
n_res
,
-
1
)
)
)
biases
=
[
bias
]
biases
=
[
bias
]
...
@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
...
@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
z
=
self
.
linear_z
(
z
)
z
=
self
.
linear_z
(
z
)
# [*, 1, no_heads, N_res, N_res]
# [*, 1, no_heads, N_res, N_res]
z
=
permute_final_dims
(
z
,
2
,
0
,
1
).
unsqueeze
(
-
4
)
z
=
permute_final_dims
(
z
,
(
2
,
0
,
1
)
)
.
unsqueeze
(
-
4
)
biases
.
append
(
z
)
biases
.
append
(
z
)
...
@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
self
.
inf
=
inf
self
.
inf
=
inf
self
.
eps
=
eps
self
.
eps
=
eps
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_in
)
self
.
layer_norm_m
=
nn
.
LayerNorm
(
c_in
)
self
.
linear_q
=
Linear
(
self
.
c_in
,
self
.
c_hidden
*
self
.
no_heads
,
bias
=
False
,
init
=
"glorot"
)
C_hidden
=
self
.
c_hidden
self
.
linear_k
=
Linear
(
self
.
c_in
,
C_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
self
.
linear_v
=
Linear
(
self
.
c_in
,
C_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
self
.
linear_g
=
Linear
(
self
.
c_in
,
self
.
c_hidden
*
self
.
no_heads
,
init
=
"gating"
)
self
.
linear_o
=
Linear
(
self
.
c_hidden
*
self
.
no_heads
,
self
.
c_in
,
init
=
"final"
)
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
global_attention
(
self
,
m
,
mask
):
# [*, N_res, C_in]
q
=
(
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
))
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
q
)
q
=
q
*
self
.
c_hidden
**
(
-
0.5
)
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
*
q
.
shape
[:
-
1
],
self
.
no_heads
,
-
1
)
# [*, N_res, N_seq, C_hidden]
self
.
global_attention
=
GlobalAttention
(
k
=
self
.
linear_k
(
m
)
c_in
=
c_in
,
v
=
self
.
linear_v
(
m
)
c_hidden
=
c_hidden
,
no_heads
=
no_heads
,
# [*, N_res, H, N_seq]
inf
=
inf
,
a
=
torch
.
matmul
(
eps
=
eps
,
q
,
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
)
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
a
=
a
+
bias
a
=
self
.
softmax
(
a
)
# [*, N_res, H, C_hidden]
o
=
torch
.
matmul
(
a
,
v
,
)
)
# [*, N_res, N_seq, C_hidden]
def
forward
(
self
,
g
=
self
.
sigmoid
(
self
.
linear_g
(
m
))
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
# [*, N_res, N_seq, H, C_hidden]
)
->
torch
.
Tensor
:
g
=
g
.
view
(
*
g
.
shape
[:
-
1
],
self
.
no_heads
,
-
1
)
# [*, N_res, N_seq, H, C_hidden]
o
=
o
.
unsqueeze
(
-
3
)
*
g
# [*, N_res, N_seq, H * C_hidden]
o
=
o
.
reshape
(
*
o
.
shape
[:
-
2
],
-
1
)
# [*, N_res, N_seq, C_in]
m
=
self
.
linear_o
(
o
)
return
m
def
forward
(
self
,
m
,
mask
=
None
):
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
if
(
mask
is
None
):
if
(
mask
is
None
):
# [*, N_seq, N_res]
# [*, N_seq, N_res]
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
],
requires_grad
=
False
)
mask
=
torch
.
ones
(
m
.
shape
[:
-
1
],
dtype
=
m
.
dtype
,
device
=
m
.
device
,
).
detach
()
# [*, N_res, N_seq, C_in]
# [*, N_res, N_seq, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
...
@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
)
)
else
:
else
:
m
=
self
.
global_attention
(
**
mha_input
)
m
=
self
.
global_attention
(
m
=
mha_input
[
"m"
],
mask
=
mha_input
[
"mask"
]
)
# [*, N_seq, N_res, C_in]
# [*, N_seq, N_res, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
...
...
openfold/model/primitives.py
View file @
b2d102cb
...
@@ -235,12 +235,6 @@ class Attention(nn.Module):
...
@@ -235,12 +235,6 @@ class Attention(nn.Module):
Returns
Returns
[*, Q, C_q] attention update
[*, Q, C_q] attention update
"""
"""
# Flatten batch dims
batch_dims
=
q_x
.
shape
[:
-
2
]
q_x
=
q_x
.
view
((
-
1
,)
+
q_x
.
shape
[
-
2
:])
k_x
=
k_x
.
view
((
-
1
,)
+
k_x
.
shape
[
-
2
:])
v_x
=
v_x
.
view
((
-
1
,)
+
v_x
.
shape
[
-
2
:])
# [*, Q/K/V, H * C_hidden]
# [*, Q/K/V, H * C_hidden]
q
=
self
.
linear_q
(
q_x
)
q
=
self
.
linear_q
(
q_x
)
k
=
self
.
linear_k
(
k_x
)
k
=
self
.
linear_k
(
k_x
)
...
@@ -253,20 +247,20 @@ class Attention(nn.Module):
...
@@ -253,20 +247,20 @@ class Attention(nn.Module):
# [*, H, Q, K]
# [*, H, Q, K]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
q
.
permute
(
0
,
2
,
1
,
3
),
# [*, H, Q, C_hidden]
permute
_final_dims
(
q
,
(
0
,
2
,
1
,
3
)
)
,
# [*, H, Q, C_hidden]
k
.
permute
(
0
,
2
,
3
,
1
),
# [*, H, C_hidden, K]
permute
_final_dims
(
k
,
(
0
,
2
,
3
,
1
)
)
,
# [*, H, C_hidden, K]
)
)
norm
=
1
/
math
.
sqrt
(
self
.
c_hidden
)
# [1]
norm
=
1
/
math
.
sqrt
(
self
.
c_hidden
)
# [1]
a
=
a
*
norm
a
*
=
norm
if
(
biases
is
not
None
):
if
(
biases
is
not
None
):
for
b
in
biases
:
for
b
in
biases
:
a
=
a
+
b
a
+
=
b
a
=
self
.
softmax
(
a
)
a
=
self
.
softmax
(
a
)
# [*, H, Q, C_hidden]
# [*, H, Q, C_hidden]
o
=
torch
.
matmul
(
o
=
torch
.
matmul
(
a
,
a
,
v
.
permute
(
0
,
2
,
1
,
3
),
# [*, H, V, C_hidden]
permute
_final_dims
(
v
,
(
0
,
2
,
1
,
3
)
)
,
# [*, H, V, C_hidden]
)
)
# [*, Q, H, C_hidden]
# [*, Q, H, C_hidden]
...
@@ -282,11 +276,80 @@ class Attention(nn.Module):
...
@@ -282,11 +276,80 @@ class Attention(nn.Module):
# [*, Q, C_q]
# [*, Q, C_q]
o
=
self
.
linear_o
(
o
)
o
=
self
.
linear_o
(
o
)
# Restore the batch dims
o
=
o
.
reshape
(
batch_dims
+
o
.
shape
[
1
:])
return
o
return
o
def
scripted_attention
(
*
args
,
**
kwargs
):
class
GlobalAttention
(
nn
.
Module
):
return
torch
.
jit
.
script
(
Attention
(
*
args
,
**
kwargs
))
def
__init__
(
self
,
c_in
,
c_hidden
,
no_heads
,
inf
,
eps
):
super
(
GlobalAttention
,
self
).
__init__
()
self
.
c_in
=
c_in
self
.
c_hidden
=
c_hidden
self
.
no_heads
=
no_heads
self
.
inf
=
inf
self
.
eps
=
eps
self
.
linear_q
=
Linear
(
c_in
,
c_hidden
*
no_heads
,
bias
=
False
,
init
=
"glorot"
)
self
.
linear_k
=
Linear
(
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
self
.
linear_v
=
Linear
(
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
self
.
linear_g
=
Linear
(
c_in
,
c_hidden
*
no_heads
,
init
=
"gating"
)
self
.
linear_o
=
Linear
(
c_hidden
*
no_heads
,
c_in
,
init
=
"final"
)
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# [*, N_res, C_in]
q
=
(
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
))
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
q
)
q
=
q
*
self
.
c_hidden
**
(
-
0.5
)
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, N_seq, C_hidden]
k
=
self
.
linear_k
(
m
)
v
=
self
.
linear_v
(
m
)
# [*, N_res, H, N_seq]
a
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
)
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
a
+=
bias
a
=
self
.
softmax
(
a
)
# [*, N_res, H, C_hidden]
o
=
torch
.
matmul
(
a
,
v
,
)
# [*, N_res, N_seq, C_hidden]
g
=
self
.
sigmoid
(
self
.
linear_g
(
m
))
# [*, N_res, N_seq, H, C_hidden]
g
=
g
.
view
(
g
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, N_seq, H, C_hidden]
o
=
o
.
unsqueeze
(
-
3
)
*
g
# [*, N_res, N_seq, H * C_hidden]
o
=
o
.
reshape
(
o
.
shape
[:
-
2
]
+
(
-
1
,))
# [*, N_res, N_seq, C_in]
m
=
self
.
linear_o
(
o
)
return
m
openfold/model/structure_module.py
View file @
b2d102cb
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
math
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
from
openfold.model.primitives
import
Linear
,
ipa_point_weights_init_
from
openfold.model.primitives
import
Linear
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
...
@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
...
@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
a
)
:
def
forward
(
self
,
a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
s_initial
=
a
s_initial
=
a
...
@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
...
@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
self
.
no_blocks
=
no_blocks
self
.
no_blocks
=
no_blocks
self
.
no_angles
=
no_angles
self
.
no_angles
=
no_angles
self
.
eps
ilon
=
epsilon
self
.
eps
=
epsilon
self
.
linear_in
=
Linear
(
self
.
c_in
,
self
.
c_hidden
)
self
.
linear_in
=
Linear
(
self
.
c_in
,
self
.
c_hidden
)
self
.
linear_initial
=
Linear
(
self
.
c_in
,
self
.
c_hidden
)
self
.
linear_initial
=
Linear
(
self
.
c_in
,
self
.
c_hidden
)
...
@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
...
@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
s
,
s_initial
):
def
forward
(
self
,
s
:
torch
.
Tensor
,
s_initial
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
s:
s:
...
@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
...
@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
s
=
self
.
linear_out
(
s
)
s
=
self
.
linear_out
(
s
)
# [*, no_angles, 2]
# [*, no_angles, 2]
s
=
s
.
view
(
*
s
.
shape
[:
-
1
]
,
-
1
,
2
)
s
=
s
.
view
(
s
.
shape
[:
-
1
]
+
(
-
1
,
2
)
)
unnormalized_s
=
s
unnormalized_s
=
s
norm_denom
=
torch
.
sqrt
(
norm_denom
=
torch
.
sqrt
(
torch
.
clamp
(
torch
.
clamp
(
torch
.
sum
(
s
**
2
,
dim
=-
1
,
keepdim
s
=
True
),
torch
.
sum
(
s
**
2
,
dim
=-
1
,
keepdim
=
True
),
min
=
self
.
eps
ilon
,
min
=
self
.
eps
,
)
)
)
)
s
=
s
/
norm_denom
s
=
s
/
norm_denom
...
@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
t
:
T
,
t
:
T
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
s:
s:
...
@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
...
@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
#######################################
#######################################
# Generate scalar and point activations
# Generate scalar and point activations
#######################################
#######################################
# [*, N_res, H * C_hidden]
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
s
)
q
=
self
.
linear_q
(
s
)
kv
=
self
.
linear_kv
(
s
)
kv
=
self
.
linear_kv
(
s
)
# [*, N_res, H, C_hidden]
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
*
q
.
shape
[:
-
1
]
,
self
.
no_heads
,
-
1
)
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
)
)
# [*, N_res, H, 2 * C_hidden]
# [*, N_res, H, 2 * C_hidden]
kv
=
kv
.
view
(
*
kv
.
shape
[:
-
1
]
,
self
.
no_heads
,
-
1
)
kv
=
kv
.
view
(
kv
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
)
)
# [*, N_res, H, C_hidden]
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
...
@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_q, 3]
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
q_pts
=
q_pts
.
view
(
*
q_pts
.
shape
[:
-
2
]
,
self
.
no_heads
,
self
.
no_qk_points
,
3
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
)
)
# [*, N_res, H * (P_q + P_v) * 3]
# [*, N_res, H * (P_q + P_v) * 3]
...
@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, (P_q + P_v), 3]
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
=
kv_pts
.
view
(
*
kv_pts
.
shape
[:
-
2
]
,
self
.
no_heads
,
-
1
,
3
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
)
)
)
# [*, N_res, H, P_q/P_v, 3]
# [*, N_res, H, P_q/P_v, 3]
...
@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
1
,
0
,
2
),
# [*, H, N_res, C_hidden]
permute_final_dims
(
q
,
(
1
,
0
,
2
)
)
,
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
,
1
,
2
,
0
),
# [*, H, C_hidden, N_res]
permute_final_dims
(
k
,
(
1
,
2
,
0
)
)
,
# [*, H, C_hidden, N_res]
)
)
a
=
a
+
math
.
sqrt
(
1.
/
(
3
*
self
.
c_hidden
))
a
=
a
+
math
.
sqrt
(
1.
/
(
3
*
self
.
c_hidden
))
a
=
a
+
math
.
sqrt
(
1.
/
3
)
*
permute_final_dims
(
b
,
2
,
0
,
1
)
a
=
a
+
math
.
sqrt
(
1.
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)
)
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
...
@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
pt_att
=
permute_final_dims
(
pt_att
,
2
,
0
,
1
)
pt_att
=
permute_final_dims
(
pt_att
,
(
2
,
0
,
1
)
)
a
=
a
+
pt_att
a
=
a
+
pt_att
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
a
=
self
.
softmax
(
a
)
a
=
self
.
softmax
(
a
)
...
@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v]
# [*, H, 3, N_res, P_v]
o_pt
=
torch
.
matmul
(
o_pt
=
torch
.
matmul
(
a
.
unsqueeze
(
-
3
),
# [*, H, 1, N_res, N_res]
a
.
unsqueeze
(
-
3
),
# [*, H, 1, N_res, N_res]
permute_final_dims
(
v_pts
,
1
,
3
,
0
,
2
),
# [*, H, 3, N_res, P_v]
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
)
)
,
# [*, H, 3, N_res, P_v]
)
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
2
,
0
,
3
,
1
)
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
)
)
o_pt
=
t
[...,
None
,
None
].
invert_apply
(
o_pt
)
o_pt
=
t
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
# [*, N_res, H * P_v]
...
@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
...
@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
return
outputs
return
outputs
def
_init_residue_constants
(
self
,
device
):
def
_init_residue_constants
(
self
,
dtype
,
device
):
if
(
self
.
default_frames
is
None
):
if
(
self
.
default_frames
is
None
):
self
.
default_frames
=
torch
.
tensor
(
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
restype_rigid_group_default_frame
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
group_idx
is
None
):
if
(
self
.
group_idx
is
None
):
self
.
group_idx
=
torch
.
tensor
(
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
restype_atom14_to_rigid_group
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
atom_mask
is
None
):
if
(
self
.
atom_mask
is
None
):
self
.
atom_mask
=
torch
.
tensor
(
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
if
(
self
.
lit_positions
is
None
):
if
(
self
.
lit_positions
is
None
):
self
.
lit_positions
=
torch
.
tensor
(
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
f
.
device
)
self
.
_init_residue_constants
(
f
.
dtype
,
f
.
device
)
# Separated purely to make testing less annoying
# Separated purely to make testing less annoying
return
_torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
_torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
...
@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
...
@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
# arrays
self
.
_init_residue_constants
(
f
.
device
)
self
.
_init_residue_constants
(
f
.
dtype
,
f
.
device
)
return
_frames_and_literature_positions_to_atom14_pos
(
return
_frames_and_literature_positions_to_atom14_pos
(
t
,
t
,
...
...
openfold/model/template.py
View file @
b2d102cb
...
@@ -18,7 +18,7 @@ import math
...
@@ -18,7 +18,7 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
scripted_a
ttention
from
openfold.model.primitives
import
Linear
,
A
ttention
from
openfold.utils.deepspeed
import
checkpoint_blocks
from
openfold.utils.deepspeed
import
checkpoint_blocks
from
openfold.model.dropout
import
(
from
openfold.model.dropout
import
(
DropoutRowwise
,
DropoutRowwise
,
...
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
self
.
no_heads
=
no_heads
self
.
no_heads
=
no_heads
self
.
chunk_size
=
chunk_size
self
.
chunk_size
=
chunk_size
self
.
mha
=
scripted_a
ttention
(
self
.
mha
=
A
ttention
(
self
.
c_z
,
self
.
c_t
,
self
.
c_t
,
self
.
c_z
,
self
.
c_t
,
self
.
c_t
,
self
.
c_hidden
,
self
.
no_heads
,
self
.
c_hidden
,
self
.
no_heads
,
gating
=
False
,
gating
=
False
,
...
@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
# NOTE: This is not the "template_mask" from the supplement, but a
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess.
# but not sure enough to remove it. It's nice to have, I guess.
template_mask
=
t
orch
.
ones
(
t
.
shape
[:
-
3
]
,
device
=
t
.
device
)
template_mask
=
t
.
new_
ones
(
t
.
shape
[:
-
3
])
bias
=
(
1e9
*
(
template_mask
[...,
None
,
None
,
None
,
None
,
:]
-
1
))
bias
=
(
1e9
*
(
template_mask
[...,
None
,
None
,
None
,
None
,
:]
-
1
))
...
@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
z
=
z
.
unsqueeze
(
-
2
)
z
=
z
.
unsqueeze
(
-
2
)
# [*, N_res, N_res, N_temp, C_t]
# [*, N_res, N_res, N_temp, C_t]
t
=
permute_final_dims
(
t
,
1
,
2
,
0
,
3
)
t
=
permute_final_dims
(
t
,
(
1
,
2
,
0
,
3
)
)
# [*, N_res, N_res, 1, C_z]
# [*, N_res, N_res, 1, C_z]
mha_inputs
=
{
mha_inputs
=
{
...
...
openfold/model/triangular_attention.py
View file @
b2d102cb
...
@@ -18,7 +18,7 @@ import math
...
@@ -18,7 +18,7 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
scripted_a
ttention
from
openfold.model.primitives
import
Linear
,
A
ttention
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
...
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
...
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
self
.
linear
=
Linear
(
c_in
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
self
.
linear
=
Linear
(
c_in
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
self
.
mha
=
scripted_a
ttention
(
self
.
mha
=
A
ttention
(
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
c_hidden
,
self
.
no_heads
self
.
no_heads
...
@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
...
@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
mask_bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
None
,
:]
mask_bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
None
,
:]
# [*, H, I, J]
# [*, H, I, J]
triangle_bias
=
permute_final_dims
(
self
.
linear
(
x
),
2
,
0
,
1
)
triangle_bias
=
permute_final_dims
(
self
.
linear
(
x
),
(
2
,
0
,
1
)
)
# [*, 1, H, I, J]
# [*, 1, H, I, J]
triangle_bias
=
triangle_bias
.
unsqueeze
(
-
4
)
triangle_bias
=
triangle_bias
.
unsqueeze
(
-
4
)
...
...
openfold/model/triangular_multiplicative_update.py
View file @
b2d102cb
...
@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
):
):
# [*, C, N_i, N_j]
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
2
,
0
,
1
),
permute_final_dims
(
a
,
(
2
,
0
,
1
)
)
,
permute_final_dims
(
b
,
2
,
1
,
0
),
permute_final_dims
(
b
,
(
2
,
1
,
0
)
)
,
)
)
# [*, N_i, N_j, C]
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
1
,
2
,
0
)
return
permute_final_dims
(
p
,
(
1
,
2
,
0
)
)
def
_incoming_matmul
(
self
,
def
_incoming_matmul
(
self
,
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
...
@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
# [*, C, N_i, N_j]
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
2
,
1
,
0
),
permute_final_dims
(
a
,
(
2
,
1
,
0
)
)
,
permute_final_dims
(
b
,
2
,
0
,
1
),
permute_final_dims
(
b
,
(
2
,
0
,
1
)
)
,
)
)
# [*, N_i, N_j, C]
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
1
,
2
,
0
)
return
permute_final_dims
(
p
,
(
1
,
2
,
0
)
)
def
forward
(
self
,
z
,
mask
=
None
):
def
forward
(
self
,
z
,
mask
=
None
):
"""
"""
...
...
openfold/utils/affine_utils.py
View file @
b2d102cb
...
@@ -13,30 +13,25 @@
...
@@ -13,30 +13,25 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
numpy
as
np
import
torch
import
torch
# According to DeepMind, this prevents rotation compositions from being
# computed on low-precision tensor cores. I'm personally skeptical that it
# makes a difference, but to get as close as possible to their outputs, I'm
# adding it.
def
rot_matmul
(
a
,
b
):
def
rot_matmul
(
a
,
b
):
e
=
...
row_1
=
torch
.
stack
([
row_1
=
torch
.
stack
([
a
[
e
,
0
,
0
]
*
b
[
e
,
0
,
0
]
+
a
[
e
,
0
,
1
]
*
b
[
e
,
1
,
0
]
+
a
[
e
,
0
,
2
]
*
b
[
e
,
2
,
0
],
a
[
...
,
0
,
0
]
*
b
[
...
,
0
,
0
]
+
a
[
...
,
0
,
1
]
*
b
[
...
,
1
,
0
]
+
a
[
...
,
0
,
2
]
*
b
[
...
,
2
,
0
],
a
[
e
,
0
,
0
]
*
b
[
e
,
0
,
1
]
+
a
[
e
,
0
,
1
]
*
b
[
e
,
1
,
1
]
+
a
[
e
,
0
,
2
]
*
b
[
e
,
2
,
1
],
a
[
...
,
0
,
0
]
*
b
[
...
,
0
,
1
]
+
a
[
...
,
0
,
1
]
*
b
[
...
,
1
,
1
]
+
a
[
...
,
0
,
2
]
*
b
[
...
,
2
,
1
],
a
[
e
,
0
,
0
]
*
b
[
e
,
0
,
2
]
+
a
[
e
,
0
,
1
]
*
b
[
e
,
1
,
2
]
+
a
[
e
,
0
,
2
]
*
b
[
e
,
2
,
2
],
a
[
...
,
0
,
0
]
*
b
[
...
,
0
,
2
]
+
a
[
...
,
0
,
1
]
*
b
[
...
,
1
,
2
]
+
a
[
...
,
0
,
2
]
*
b
[
...
,
2
,
2
],
],
dim
=-
1
)
],
dim
=-
1
)
row_2
=
torch
.
stack
([
row_2
=
torch
.
stack
([
a
[
e
,
1
,
0
]
*
b
[
e
,
0
,
0
]
+
a
[
e
,
1
,
1
]
*
b
[
e
,
1
,
0
]
+
a
[
e
,
1
,
2
]
*
b
[
e
,
2
,
0
],
a
[
...
,
1
,
0
]
*
b
[
...
,
0
,
0
]
+
a
[
...
,
1
,
1
]
*
b
[
...
,
1
,
0
]
+
a
[
...
,
1
,
2
]
*
b
[
...
,
2
,
0
],
a
[
e
,
1
,
0
]
*
b
[
e
,
0
,
1
]
+
a
[
e
,
1
,
1
]
*
b
[
e
,
1
,
1
]
+
a
[
e
,
1
,
2
]
*
b
[
e
,
2
,
1
],
a
[
...
,
1
,
0
]
*
b
[
...
,
0
,
1
]
+
a
[
...
,
1
,
1
]
*
b
[
...
,
1
,
1
]
+
a
[
...
,
1
,
2
]
*
b
[
...
,
2
,
1
],
a
[
e
,
1
,
0
]
*
b
[
e
,
0
,
2
]
+
a
[
e
,
1
,
1
]
*
b
[
e
,
1
,
2
]
+
a
[
e
,
1
,
2
]
*
b
[
e
,
2
,
2
],
a
[
...
,
1
,
0
]
*
b
[
...
,
0
,
2
]
+
a
[
...
,
1
,
1
]
*
b
[
...
,
1
,
2
]
+
a
[
...
,
1
,
2
]
*
b
[
...
,
2
,
2
],
],
dim
=-
1
)
],
dim
=-
1
)
row_3
=
torch
.
stack
([
row_3
=
torch
.
stack
([
a
[
e
,
2
,
0
]
*
b
[
e
,
0
,
0
]
+
a
[
e
,
2
,
1
]
*
b
[
e
,
1
,
0
]
+
a
[
e
,
2
,
2
]
*
b
[
e
,
2
,
0
],
a
[
...
,
2
,
0
]
*
b
[
...
,
0
,
0
]
+
a
[
...
,
2
,
1
]
*
b
[
...
,
1
,
0
]
+
a
[
...
,
2
,
2
]
*
b
[
...
,
2
,
0
],
a
[
e
,
2
,
0
]
*
b
[
e
,
0
,
1
]
+
a
[
e
,
2
,
1
]
*
b
[
e
,
1
,
1
]
+
a
[
e
,
2
,
2
]
*
b
[
e
,
2
,
1
],
a
[
...
,
2
,
0
]
*
b
[
...
,
0
,
1
]
+
a
[
...
,
2
,
1
]
*
b
[
...
,
1
,
1
]
+
a
[
...
,
2
,
2
]
*
b
[
...
,
2
,
1
],
a
[
e
,
2
,
0
]
*
b
[
e
,
0
,
2
]
+
a
[
e
,
2
,
1
]
*
b
[
e
,
1
,
2
]
+
a
[
e
,
2
,
2
]
*
b
[
e
,
2
,
2
],
a
[
...
,
2
,
0
]
*
b
[
...
,
0
,
2
]
+
a
[
...
,
2
,
1
]
*
b
[
...
,
1
,
2
]
+
a
[
...
,
2
,
2
]
*
b
[
...
,
2
,
2
],
],
dim
=-
1
)
],
dim
=-
1
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
...
@@ -175,7 +170,7 @@ class T:
...
@@ -175,7 +170,7 @@ class T:
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
to_4x4
(
self
):
def
to_4x4
(
self
):
tensor
=
torch
.
zeros
((
*
self
.
shape
,
4
,
4
)
,
device
=
self
.
rots
.
device
)
tensor
=
self
.
rots
.
new_
zeros
((
*
self
.
shape
,
4
,
4
))
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
3
,
3
]
=
self
.
trans
tensor
[...,
:
3
,
3
]
=
self
.
trans
tensor
[...,
3
,
3
]
=
1
tensor
[...,
3
,
3
]
=
1
...
@@ -311,7 +306,7 @@ def _to_mat(pairs):
...
@@ -311,7 +306,7 @@ def _to_mat(pairs):
return
mat
return
mat
_qtr_mat
=
torch
.
zeros
((
4
,
4
,
3
,
3
))
_qtr_mat
=
np
.
zeros
((
4
,
4
,
3
,
3
))
_qtr_mat
[...,
0
,
0
]
=
_to_mat
([(
'aa'
,
1
),
(
'bb'
,
1
),
(
'cc'
,
-
1
),
(
'dd'
,
-
1
)])
_qtr_mat
[...,
0
,
0
]
=
_to_mat
([(
'aa'
,
1
),
(
'bb'
,
1
),
(
'cc'
,
-
1
),
(
'dd'
,
-
1
)])
_qtr_mat
[...,
0
,
1
]
=
_to_mat
([(
'bc'
,
2
),
(
'ad'
,
-
2
)])
_qtr_mat
[...,
0
,
1
]
=
_to_mat
([(
'bc'
,
2
),
(
'ad'
,
-
2
)])
_qtr_mat
[...,
0
,
2
]
=
_to_mat
([(
'bd'
,
2
),
(
'ac'
,
2
)])
_qtr_mat
[...,
0
,
2
]
=
_to_mat
([(
'bd'
,
2
),
(
'ac'
,
2
)])
...
@@ -328,9 +323,11 @@ def quat_to_rot(
...
@@ -328,9 +323,11 @@ def quat_to_rot(
# [*, 4, 4]
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
mat
=
quat
.
new_tensor
(
_qtr_mat
)
# [*, 4, 4, 3, 3]
# [*, 4, 4, 3, 3]
shaped_qtr_mat
=
_qtr_
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
(
4
,
4
,
3
,
3
))
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
(
4
,
4
,
3
,
3
))
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
.
to
(
quat
.
device
)
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
# [*, 3, 3]
# [*, 3, 3]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
...
@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
...
@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
quats
=
vector
[...,
:
4
]
quats
=
vector
[...,
:
4
]
trans
=
vector
[...,
4
:]
trans
=
vector
[...,
4
:]
four_by_four
=
torch
.
zeros
(
four_by_four
=
vector
.
new_zeros
((
*
vector
.
shape
[:
-
1
],
4
,
4
))
(
*
vector
.
shape
[:
-
1
],
4
,
4
),
device
=
vector
.
device
)
four_by_four
[...,
:
3
,
:
3
]
=
quat_to_rot
(
quats
)
four_by_four
[...,
:
3
,
:
3
]
=
quat_to_rot
(
quats
)
four_by_four
[...,
:
3
,
3
]
=
trans
four_by_four
[...,
:
3
,
3
]
=
trans
four_by_four
[...,
3
,
3
]
=
1
four_by_four
[...,
3
,
3
]
=
1
...
...
openfold/utils/deepspeed.py
View file @
b2d102cb
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
deepspeed
import
deepspeed
import
torch
import
torch
from
torch.utils.checkpoint
import
checkpoint
from
typing
import
Any
,
Tuple
,
List
,
Callable
from
typing
import
Any
,
Tuple
,
List
,
Callable
BLOCK_ARG
=
Any
BLOCK_ARG
=
Any
...
@@ -55,7 +56,7 @@ def checkpoint_blocks(
...
@@ -55,7 +56,7 @@ def checkpoint_blocks(
return
a
return
a
def
chunker
(
s
,
e
):
def
chunker
(
s
,
e
):
def
exec_sliced
(
a
):
def
exec_sliced
(
*
a
):
return
exec
(
blocks
[
s
:
e
],
a
)
return
exec
(
blocks
[
s
:
e
],
a
)
return
exec_sliced
return
exec_sliced
...
@@ -69,7 +70,7 @@ def checkpoint_blocks(
...
@@ -69,7 +70,7 @@ def checkpoint_blocks(
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
e
=
s
+
blocks_per_ckpt
e
=
s
+
blocks_per_ckpt
args
=
deepspeed
.
checkpointing
.
checkpoint
(
chunker
(
s
,
e
),
args
)
args
=
deepspeed
.
checkpointing
.
checkpoint
(
chunker
(
s
,
e
),
*
args
)
args
=
wrap
(
args
)
args
=
wrap
(
args
)
return
args
return
args
openfold/utils/import_weights.py
View file @
b2d102cb
...
@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams
=
lambda
matt
:
{
MSAGlobalAttParams
=
lambda
matt
:
{
"query_norm"
:
LayerNormParams
(
matt
.
layer_norm_m
),
"query_norm"
:
LayerNormParams
(
matt
.
layer_norm_m
),
"attention"
:
GlobalAttentionParams
(
matt
)
"attention"
:
GlobalAttentionParams
(
matt
.
global_attention
)
}
}
MSAAttPairBiasParams
=
lambda
matt
:
dict
(
MSAAttPairBiasParams
=
lambda
matt
:
dict
(
...
...
openfold/utils/loss.py
View file @
b2d102cb
...
@@ -356,7 +356,7 @@ def lddt_loss(
...
@@ -356,7 +356,7 @@ def lddt_loss(
)
)
dists_to_score
=
(
dists_to_score
=
(
(
dmat_true
<
cutoff
)
*
all_atom_mask
*
(
dmat_true
<
cutoff
)
*
all_atom_mask
*
permute_final_dims
(
all_atom_mask
,
1
,
0
)
*
permute_final_dims
(
all_atom_mask
,
(
1
,
0
)
)
*
(
1.
-
torch
.
eye
(
n
,
device
=
all_atom_mask
.
device
))
(
1.
-
torch
.
eye
(
n
,
device
=
all_atom_mask
.
device
))
)
)
...
...
openfold/utils/tensor_utils.py
View file @
b2d102cb
...
@@ -16,12 +16,13 @@
...
@@ -16,12 +16,13 @@
from
functools
import
partial
from
functools
import
partial
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
def
permute_final_dims
(
tensor
,
*
inds
):
def
permute_final_dims
(
tensor
:
torch
.
Tensor
,
inds
:
List
[
int
]
):
zero_index
=
-
1
*
len
(
inds
)
zero_index
=
-
1
*
len
(
inds
)
first_inds
=
range
(
len
(
tensor
.
shape
[:
zero_index
]))
first_inds
=
list
(
range
(
len
(
tensor
.
shape
[:
zero_index
]))
)
return
tensor
.
permute
(
*
first_inds
,
*
[
zero_index
+
i
for
i
in
inds
])
return
tensor
.
permute
(
first_inds
+
[
zero_index
+
i
for
i
in
inds
])
def
flatten_final_dims
(
tensor
:
torch
.
Tensor
,
no_dims
:
int
):
def
flatten_final_dims
(
tensor
:
torch
.
Tensor
,
no_dims
:
int
):
...
@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
...
@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
def
one_hot
(
x
,
v_bins
):
def
one_hot
(
x
,
v_bins
):
reshaped_bins
=
v_bins
.
view
(
*
((
1
,)
*
len
(
x
.
shape
)
+
(
len
(
v_bins
),))
)
reshaped_bins
=
v_bins
.
view
(((
1
,)
*
len
(
x
.
shape
)
)
+
(
len
(
v_bins
),))
diffs
=
x
[...,
None
]
-
reshaped_bins
diffs
=
x
[...,
None
]
-
reshaped_bins
am
=
torch
.
argmin
(
torch
.
abs
(
diffs
),
dim
=-
1
)
am
=
torch
.
argmin
(
torch
.
abs
(
diffs
),
dim
=-
1
)
return
nn
.
functional
.
one_hot
(
am
,
num_classes
=
len
(
v_bins
)).
float
()
return
nn
.
functional
.
one_hot
(
am
,
num_classes
=
len
(
v_bins
)).
float
()
...
@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
...
@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map
=
partial
(
tree_map
,
leaf_type
=
torch
.
Tensor
)
tensor_tree_map
=
partial
(
tree_map
,
leaf_type
=
torch
.
Tensor
)
def
chunk_layer
(
layer
,
inputs
,
chunk_size
,
no_batch_dims
):
def
chunk_layer
(
layer
:
Callable
,
inputs
:
Dict
[
str
,
Any
],
chunk_size
:
int
,
no_batch_dims
:
int
,
)
->
Any
:
"""
"""
Implements the "chunking" procedure described in section 1.11.8.
Implements the "chunking" procedure described in section 1.11.8.
...
@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
...
@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
layer:
layer:
The layer to be applied chunk-wise
The layer to be applied chunk-wise
inputs:
inputs:
A (nested) dictionary of keyworded inputs. All leaves must
be
A (
non-
nested) dictionary of keyworded inputs. All leaves must
tensors and must share the same batch dimensions.
be
tensors and must share the same batch dimensions.
chunk_size:
chunk_size:
The number of sub-batches per chunk. If multiple batch
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
dimensions are specified, a "sub-batch" is defined as a single
...
@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
...
@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
return
shapes
return
shapes
initial_dims
=
[
shape
[:
no_batch_dims
]
for
shape
in
fetch_dims
(
inputs
)]
initial_dims
=
[
shape
[:
no_batch_dims
]
for
shape
in
fetch_dims
(
inputs
)]
orig_batch_dims
=
[
max
(
s
)
for
s
in
zip
(
*
initial_dims
)]
orig_batch_dims
=
tuple
(
[
max
(
s
)
for
s
in
zip
(
*
initial_dims
)]
)
def
prep_inputs
(
t
):
def
prep_inputs
(
t
):
# TODO: make this more memory efficient. This sucks
# TODO: make this more memory efficient. This sucks
...
@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
...
@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
# Allocate space for the output
# Allocate space for the output
if
(
out
is
None
):
if
(
out
is
None
):
allocate
=
lambda
t
:
t
.
new_zeros
(
flat_batch_dim
,
*
t
.
shape
[
1
:])
allocate
=
lambda
t
:
t
.
new_zeros
(
(
flat_batch_dim
,
)
+
t
.
shape
[
1
:])
out
=
tensor_tree_map
(
allocate
,
output_chunk
)
out
=
tensor_tree_map
(
allocate
,
output_chunk
)
# Put the chunk in its pre-allocated space
# Put the chunk in its pre-allocated space
...
@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
...
@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
i
+=
chunk_size
i
+=
chunk_size
reshape
=
lambda
t
:
t
.
reshape
(
*
orig_batch_dims
,
*
t
.
shape
[
1
:])
reshape
=
lambda
t
:
t
.
reshape
(
orig_batch_dims
+
t
.
shape
[
1
:])
out
=
tensor_tree_map
(
reshape
,
out
)
out
=
tensor_tree_map
(
reshape
,
out
)
return
out
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment