Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
893fe372
Commit
893fe372
authored
Oct 05, 2021
by
Gustaf Ahdritz
Browse files
Get FP16 training working
parent
dd06b323
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
203 additions
and
161 deletions
+203
-161
config.py
config.py
+6
-6
openfold/model/model.py
openfold/model/model.py
+5
-2
openfold/model/primitives.py
openfold/model/primitives.py
+0
-2
openfold/model/structure_module.py
openfold/model/structure_module.py
+16
-107
openfold/utils/affine_utils.py
openfold/utils/affine_utils.py
+28
-12
openfold/utils/deepspeed.py
openfold/utils/deepspeed.py
+2
-5
openfold/utils/feats.py
openfold/utils/feats.py
+123
-6
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+1
-1
openfold/utils/loss.py
openfold/utils/loss.py
+19
-17
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+3
-3
No files found.
config.py
View file @
893fe372
...
@@ -44,7 +44,7 @@ def model_config(name, train=False, low_prec=False):
...
@@ -44,7 +44,7 @@ def model_config(name, train=False, low_prec=False):
raise
ValueError
(
"Invalid model name"
)
raise
ValueError
(
"Invalid model name"
)
if
(
train
):
if
(
train
):
c
.
globals
.
model
.
blocks_per_ckpt
=
1
c
.
globals
.
blocks_per_ckpt
=
1
c
.
globals
.
chunk_size
=
None
c
.
globals
.
chunk_size
=
None
if
(
low_prec
):
if
(
low_prec
):
...
@@ -137,7 +137,7 @@ config = mlc.ConfigDict({
...
@@ -137,7 +137,7 @@ config = mlc.ConfigDict({
},
},
"inf"
:
1e9
,
"inf"
:
1e9
,
"eps"
:
eps
,
#1e-6,
"eps"
:
eps
,
#1e-6,
"enabled"
:
True
,
"enabled"
:
False
,
#
True,
"embed_angles"
:
True
,
"embed_angles"
:
True
,
},
},
"extra_msa"
:
{
"extra_msa"
:
{
...
@@ -239,7 +239,7 @@ config = mlc.ConfigDict({
...
@@ -239,7 +239,7 @@ config = mlc.ConfigDict({
"max_bin"
:
21.6875
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"no_bins"
:
64
,
"eps"
:
eps
,
#1e-6,
"eps"
:
eps
,
#1e-6,
"weight"
:
0.
,
#
0.3,
"weight"
:
0.3
,
},
},
"experimentally_resolved"
:
{
"experimentally_resolved"
:
{
"eps"
:
eps
,
#1e-8,
"eps"
:
eps
,
#1e-8,
...
@@ -267,17 +267,17 @@ config = mlc.ConfigDict({
...
@@ -267,17 +267,17 @@ config = mlc.ConfigDict({
"cutoff"
:
15.
,
"cutoff"
:
15.
,
"no_bins"
:
50
,
"no_bins"
:
50
,
"eps"
:
eps
,
#1e-10,
"eps"
:
eps
,
#1e-10,
"weight"
:
0.
,
#
0.01,
"weight"
:
0.01
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"eps"
:
eps
,
#1e-8,
"eps"
:
eps
,
#1e-8,
"weight"
:
0.
,
#
2.0,
"weight"
:
2.0
,
},
},
"supervised_chi"
:
{
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
#1e-6,
"eps"
:
eps
,
#1e-6,
"weight"
:
0.
,
#
1.0,
"weight"
:
1.0
,
},
},
"violation"
:
{
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"violation_tolerance_factor"
:
12.0
,
...
...
openfold/model/model.py
View file @
893fe372
...
@@ -389,6 +389,7 @@ class AlphaFold(nn.Module):
...
@@ -389,6 +389,7 @@ class AlphaFold(nn.Module):
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
is_grad_enabled
=
torch
.
is_grad_enabled
()
is_grad_enabled
=
torch
.
is_grad_enabled
()
self
.
_disable_activation_checkpointing
()
# Main recycling loop
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
no_cycles
):
for
cycle_no
in
range
(
self
.
config
.
no_cycles
):
...
@@ -400,8 +401,10 @@ class AlphaFold(nn.Module):
...
@@ -400,8 +401,10 @@ class AlphaFold(nn.Module):
is_final_iter
=
(
cycle_no
==
(
self
.
config
.
no_cycles
-
1
))
is_final_iter
=
(
cycle_no
==
(
self
.
config
.
no_cycles
-
1
))
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug discussed in pytorch issue #65766
# Sidestep AMP bug discussed in pytorch issue #65766
if
(
is_final_iter
and
torch
.
is_autocast_enabled
()):
if
(
is_final_iter
):
torch
.
clear_autocast_cache
()
self
.
_enable_activation_checkpointing
()
if
(
torch
.
is_autocast_enabled
()):
torch
.
clear_autocast_cache
()
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
m_1_prev
,
z_prev
,
x_prev
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
...
...
openfold/model/primitives.py
View file @
893fe372
...
@@ -257,8 +257,6 @@ class Attention(nn.Module):
...
@@ -257,8 +257,6 @@ class Attention(nn.Module):
a
=
a
+
b
a
=
a
+
b
a
=
self
.
softmax
(
a
)
a
=
self
.
softmax
(
a
)
#print(torch.any(torch.isnan(a)))
# [*, H, Q, C_hidden]
# [*, H, Q, C_hidden]
o
=
torch
.
matmul
(
o
=
torch
.
matmul
(
a
,
a
,
...
...
openfold/model/structure_module.py
View file @
893fe372
...
@@ -26,6 +26,10 @@ from openfold.np.residue_constants import (
...
@@ -26,6 +26,10 @@ from openfold.np.residue_constants import (
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
)
)
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
dict_multimap
,
dict_multimap
,
permute_final_dims
,
permute_final_dims
,
...
@@ -305,7 +309,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -305,7 +309,7 @@ class InvariantPointAttention(nn.Module):
pt_att
=
pt_att
**
2
pt_att
=
pt_att
**
2
# [*, N_res, N_res, H, P_q]
# [*, N_res, N_res, H, P_q]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
)
)
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
)
)
...
@@ -358,7 +362,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -358,7 +362,7 @@ class InvariantPointAttention(nn.Module):
)
)
# [*, N_res, H * P_v, 3]
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
view
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
# [*, N_res, H, C_z]
# [*, N_res, H, C_z]
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
)
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
)
...
@@ -409,14 +413,17 @@ class BackboneUpdate(nn.Module):
...
@@ -409,14 +413,17 @@ class BackboneUpdate(nn.Module):
quats
,
trans
=
params
[...,:
3
],
params
[...,
3
:]
quats
,
trans
=
params
[...,:
3
],
params
[...,
3
:]
# [*]
# [*]
#norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
# As many ones as there are dimensions in quats
# [*, 3]
ones
=
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
))
ones
=
(
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
)).
expand
(
quats
.
shape
[:
-
1
]
+
(
1
,))
)
# [*, 4]
# [*, 4]
quats
=
torch
.
cat
(
(
ones
.
expand
(
*
quats
.
shape
[:
-
1
],
1
)
,
quats
)
,
dim
=-
1
)
quats
=
torch
.
cat
(
[
ones
,
quats
]
,
dim
=-
1
)
quats
=
quats
/
norm_denom
.
unsqueeze
(
-
1
)
quats
=
quats
/
norm_denom
[...,
None
]
# [*, 3, 3]
# [*, 3, 3]
rots
=
quat_to_rot
(
quats
)
rots
=
quat_to_rot
(
quats
)
...
@@ -424,105 +431,6 @@ class BackboneUpdate(nn.Module):
...
@@ -424,105 +431,6 @@ class BackboneUpdate(nn.Module):
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
_torsion_angles_to_frames
(
t
,
alpha
,
f
,
rrgdf
):
# [*, N, 8, 4, 4]
default_4x4
=
rrgdf
[
f
,...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_t
=
T
.
from_4x4
(
default_4x4
)
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
# [*, N, 8, 2]
alpha
=
torch
.
cat
(
[
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_t
.
rots
.
shape
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
T
(
all_rots
,
None
)
all_frames
=
default_t
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi4_frame_to_frame
=
all_frames
[...,
7
]
chi1_frame_to_bb
=
all_frames
[...,
4
]
chi2_frame_to_bb
=
chi1_frame_to_bb
.
compose
(
chi2_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
T
.
concat
([
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi3_frame_to_bb
.
unsqueeze
(
-
1
),
chi4_frame_to_bb
.
unsqueeze
(
-
1
),
],
dim
=-
1
,
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
def
_frames_and_literature_positions_to_atom14_pos
(
t
,
f
,
default_frames
,
group_idx
,
atom_mask
,
lit_positions
,
):
# [*, N, 14, 4, 4]
default_4x4
=
default_frames
[
f
,
...]
# [*, N, 14]
group_mask
=
group_idx
[
f
,
...]
# [*, N, 14, 8]
group_mask
=
nn
.
functional
.
one_hot
(
group_mask
,
num_classes
=
default_frames
.
shape
[
-
3
],
)
# [*, N, 14, 8]
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
# [*, N, 14]
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
# [*, N, 14, 1]
atom_mask
=
atom_mask
[
f
,...].
unsqueeze
(
-
1
)
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
f
,
...]
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
=
pred_positions
*
atom_mask
return
pred_positions
class
StructureModuleTransitionLayer
(
nn
.
Module
):
class
StructureModuleTransitionLayer
(
nn
.
Module
):
def
__init__
(
self
,
c
):
def
__init__
(
self
,
c
):
super
(
StructureModuleTransitionLayer
,
self
).
__init__
()
super
(
StructureModuleTransitionLayer
,
self
).
__init__
()
...
@@ -664,6 +572,7 @@ class StructureModule(nn.Module):
...
@@ -664,6 +572,7 @@ class StructureModule(nn.Module):
self
.
no_qk_points
,
self
.
no_qk_points
,
self
.
no_v_points
,
self
.
no_v_points
,
inf
=
self
.
inf
,
inf
=
self
.
inf
,
eps
=
self
.
epsilon
,
)
)
self
.
ipa_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
self
.
ipa_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
...
@@ -791,7 +700,7 @@ class StructureModule(nn.Module):
...
@@ -791,7 +700,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
# Separated purely to make testing less annoying
return
_
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
def
frames_and_literature_positions_to_atom14_pos
(
self
,
t
,
# [*, N, 8]
t
,
# [*, N, 8]
...
@@ -799,7 +708,7 @@ class StructureModule(nn.Module):
...
@@ -799,7 +708,7 @@ class StructureModule(nn.Module):
):
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
return
_
frames_and_literature_positions_to_atom14_pos
(
return
frames_and_literature_positions_to_atom14_pos
(
t
,
t
,
f
,
f
,
self
.
default_frames
,
self
.
default_frames
,
...
...
openfold/utils/affine_utils.py
View file @
893fe372
...
@@ -188,17 +188,29 @@ class T:
...
@@ -188,17 +188,29 @@ class T:
@
staticmethod
@
staticmethod
def
from_3_points
(
p_neg_x_axis
,
origin
,
p_xy_plane
,
eps
=
1e-8
):
def
from_3_points
(
p_neg_x_axis
,
origin
,
p_xy_plane
,
eps
=
1e-8
):
e0
=
origin
-
p_neg_x_axis
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
e1
=
p_xy_plane
-
origin
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
e0
=
e0
/
torch
.
sqrt
(
torch
.
sum
(
e0
**
2
,
dim
=-
1
,
keepdims
=
True
)
+
eps
)
e1
=
e1
-
e0
*
torch
.
sum
(
e0
*
e1
,
dim
=-
1
,
keepdims
=
True
)
e0
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
origin
,
p_neg_x_axis
)]
e1
=
e1
/
torch
.
sqrt
(
torch
.
sum
(
e1
**
2
,
dim
=-
1
,
keepdims
=
True
)
+
eps
)
e1
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
p_xy_plane
,
origin
)]
e2
=
torch
.
cross
(
e0
,
e1
)
denom
=
torch
.
sqrt
(
sum
((
c
*
c
for
c
in
e0
))
+
eps
)
rots
=
torch
.
stack
([
e0
,
e1
,
e2
],
dim
=-
1
)
e0
=
[
c
/
denom
for
c
in
e0
]
dot
=
sum
((
c1
*
c2
for
c1
,
c2
in
zip
(
e0
,
e1
)))
return
T
(
rots
,
origin
)
e1
=
[
c1
-
c2
*
dot
for
c1
,
c2
in
zip
(
e1
,
e0
)]
denom
=
torch
.
sqrt
(
sum
((
c
*
c
for
c
in
e1
))
+
eps
)
e1
=
[
c
/
denom
for
c
in
e1
]
e2
=
[
e0
[
1
]
*
e1
[
2
]
-
e0
[
2
]
*
e1
[
1
],
e0
[
2
]
*
e1
[
0
]
-
e0
[
0
]
*
e1
[
2
],
e0
[
0
]
*
e1
[
1
]
-
e0
[
1
]
*
e1
[
0
],
]
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
@
staticmethod
@
staticmethod
def
concat
(
ts
,
dim
):
def
concat
(
ts
,
dim
):
...
@@ -294,6 +306,9 @@ class T:
...
@@ -294,6 +306,9 @@ class T:
return
T
(
rots
,
translation
)
return
T
(
rots
,
translation
)
def
cuda
(
self
):
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
_quat_elements
=
[
'a'
,
'b'
,
'c'
,
'd'
]
_quat_elements
=
[
'a'
,
'b'
,
'c'
,
'd'
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
...
@@ -325,10 +340,11 @@ def quat_to_rot(
...
@@ -325,10 +340,11 @@ def quat_to_rot(
# [*, 4, 4]
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
# [4, 4, 3, 3]
mat
=
quat
.
new_tensor
(
_qtr_mat
)
mat
=
quat
.
new_tensor
(
_qtr_mat
)
# [*, 4, 4, 3, 3]
# [*, 4, 4, 3, 3]
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
(
4
,
4
,
3
,
3
)
)
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
mat
.
shape
)
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
# [*, 3, 3]
# [*, 3, 3]
...
...
openfold/utils/deepspeed.py
View file @
893fe372
...
@@ -70,11 +70,8 @@ def checkpoint_blocks(
...
@@ -70,11 +70,8 @@ def checkpoint_blocks(
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
e
=
s
+
blocks_per_ckpt
e
=
s
+
blocks_per_ckpt
#print(len(args))
#args = checkpoint(chunker(s, e), *args)
#for a in args:
args
=
deepspeed
.
checkpointing
.
checkpoint
(
chunker
(
s
,
e
),
*
args
)
# print(a.requires_grad)
args
=
checkpoint
(
chunker
(
s
,
e
),
*
args
)
#args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args
=
wrap
(
args
)
args
=
wrap
(
args
)
return
args
return
args
openfold/utils/feats.py
View file @
893fe372
...
@@ -173,6 +173,11 @@ def atom37_to_torsion_angles(
...
@@ -173,6 +173,11 @@ def atom37_to_torsion_angles(
**
kwargs
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
Args:
aatype:
aatype:
[*, N_res] residue indices
[*, N_res] residue indices
...
@@ -228,10 +233,10 @@ def atom37_to_torsion_angles(
...
@@ -228,10 +233,10 @@ def atom37_to_torsion_angles(
)
)
phi_mask
=
(
phi_mask
=
(
prev_all_atom_mask
[...,
2
]
*
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
)
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
)
psi_mask
=
(
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
)
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
all_atom_mask
[...,
4
]
all_atom_mask
[...,
4
]
)
)
...
@@ -256,7 +261,9 @@ def atom37_to_torsion_angles(
...
@@ -256,7 +261,9 @@ def atom37_to_torsion_angles(
dim
=-
1
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
)
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
)
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
torsions_atom_pos
=
torch
.
cat
(
torsions_atom_pos
=
torch
.
cat
(
...
@@ -281,6 +288,7 @@ def atom37_to_torsion_angles(
...
@@ -281,6 +288,7 @@ def atom37_to_torsion_angles(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
torsions_atom_pos
[...,
0
,
:],
eps
=
eps
,
)
)
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
...
@@ -290,15 +298,19 @@ def atom37_to_torsion_angles(
...
@@ -290,15 +298,19 @@ def atom37_to_torsion_angles(
torsion_angles_sin_cos
=
torch
.
stack
(
torsion_angles_sin_cos
=
torch
.
stack
(
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
)
)
denom
=
torch
.
sqrt
(
denom
=
torch
.
sqrt
(
torch
.
sum
(
torch
.
sum
(
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
keepdims
=
True
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
)
+
eps
)
+
eps
)
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
torch
.
tensor
(
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_
tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
device
=
aatype
.
device
,
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
...
@@ -327,6 +339,7 @@ def atom37_to_frames(
...
@@ -327,6 +339,7 @@ def atom37_to_frames(
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
,
**
kwargs
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
...
@@ -387,6 +400,7 @@ def atom37_to_frames(
...
@@ -387,6 +400,7 @@ def atom37_to_frames(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
eps
,
)
)
group_exists
=
batched_gather
(
group_exists
=
batched_gather
(
...
@@ -638,3 +652,106 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
...
@@ -638,3 +652,106 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
"atom14_alt_gt_positions"
:
atom14_alt_gt_positions
,
"atom14_alt_gt_positions"
:
atom14_alt_gt_positions
,
"atom14_alt_gt_exists"
:
atom14_alt_gt_exists
,
"atom14_alt_gt_exists"
:
atom14_alt_gt_exists
,
}
}
def
torsion_angles_to_frames
(
t
:
T
,
alpha
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
):
# [*, N, 8, 4, 4]
default_4x4
=
rrgdf
[
aatype
,
...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_t
=
T
.
from_4x4
(
default_4x4
)
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
# [*, N, 8, 2]
alpha
=
torch
.
cat
(
[
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_t
.
rots
.
shape
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
T
(
all_rots
,
None
)
all_frames
=
default_t
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi4_frame_to_frame
=
all_frames
[...,
7
]
chi1_frame_to_bb
=
all_frames
[...,
4
]
chi2_frame_to_bb
=
chi1_frame_to_bb
.
compose
(
chi2_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
T
.
concat
([
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi3_frame_to_bb
.
unsqueeze
(
-
1
),
chi4_frame_to_bb
.
unsqueeze
(
-
1
),
],
dim
=-
1
,
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
def
frames_and_literature_positions_to_atom14_pos
(
t
:
T
,
aatype
:
torch
.
Tensor
,
default_frames
,
group_idx
,
atom_mask
,
lit_positions
,
):
# [*, N, 14, 4, 4]
default_4x4
=
default_frames
[
aatype
,
...]
# [*, N, 14]
group_mask
=
group_idx
[
aatype
,
...]
# [*, N, 14, 8]
group_mask
=
nn
.
functional
.
one_hot
(
group_mask
,
num_classes
=
default_frames
.
shape
[
-
3
],
)
# [*, N, 14, 8]
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
# [*, N, 14]
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
# [*, N, 14, 1]
atom_mask
=
atom_mask
[
aatype
,
...].
unsqueeze
(
-
1
)
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
=
pred_positions
*
atom_mask
return
pred_positions
openfold/utils/import_weights.py
View file @
893fe372
...
@@ -417,7 +417,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -417,7 +417,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
if
(
"_ptm"
in
version
):
if
(
"_ptm"
in
version
):
translations
[
"predicted_aligned_error_head"
]
=
{
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
_score
.
linear
)
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
}
# Flatten keys and insert missing key prefixes
# Flatten keys and insert missing key prefixes
...
...
openfold/utils/loss.py
View file @
893fe372
...
@@ -273,7 +273,6 @@ def supervised_chi_loss(
...
@@ -273,7 +273,6 @@ def supervised_chi_loss(
shifted_mask
=
(
1
-
2
*
chi_pi_periodic
).
unsqueeze
(
-
1
)
shifted_mask
=
(
1
-
2
*
chi_pi_periodic
).
unsqueeze
(
-
1
)
true_chi_shifted
=
shifted_mask
*
true_chi
true_chi_shifted
=
shifted_mask
*
true_chi
sq_chi_error
=
torch
.
sum
(
sq_chi_error
=
torch
.
sum
(
(
true_chi
-
pred_angles
)
**
2
,
dim
=-
1
(
true_chi
-
pred_angles
)
**
2
,
dim
=-
1
)
)
...
@@ -498,11 +497,11 @@ def tm_loss(
...
@@ -498,11 +497,11 @@ def tm_loss(
)
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
scale
=
0.
1
# hack to help FP16 training along
scale
=
0.
5
# hack to help FP16 training along
denom
=
eps
+
torch
.
sum
(
scale
*
square_mask
,
dim
=
(
-
1
,
-
2
))
denom
=
eps
+
torch
.
sum
(
scale
*
square_mask
,
dim
=
(
-
1
,
-
2
))
loss
=
loss
/
denom
[...,
None
]
loss
=
loss
/
denom
[...,
None
]
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
loss
/
scale
loss
=
loss
*
scale
loss
=
loss
*
(
loss
=
loss
*
(
(
resolution
>=
min_resolution
)
&
(
resolution
>=
min_resolution
)
&
...
@@ -744,7 +743,7 @@ def between_residue_clash_loss(
...
@@ -744,7 +743,7 @@ def between_residue_clash_loss(
# Backbone C--N bond between subsequent residues is no clash.
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
c_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
residue_index
.
new_tensor
(
2
.
),
num_classes
=
14
residue_index
.
new_tensor
(
2
),
num_classes
=
14
)
)
c_one_hot
=
c_one_hot
.
reshape
(
c_one_hot
=
c_one_hot
.
reshape
(
*
((
1
,)
*
len
(
residue_index
.
shape
[:
-
1
])),
*
c_one_hot
.
shape
*
((
1
,)
*
len
(
residue_index
.
shape
[:
-
1
])),
*
c_one_hot
.
shape
...
@@ -1319,11 +1318,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
...
@@ -1319,11 +1318,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
# )
# )
loss
=
errors
*
bert_mask
loss
=
errors
*
bert_mask
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
scale
=
0.
1
scale
=
0.
5
denom
=
eps
+
torch
.
sum
(
scale
*
bert_mask
,
dim
=
(
-
1
,
-
2
))
denom
=
eps
+
torch
.
sum
(
scale
*
bert_mask
,
dim
=
(
-
1
,
-
2
))
loss
=
loss
/
denom
[...,
None
]
loss
=
loss
/
denom
[...,
None
]
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
loss
/
scale
loss
=
loss
*
scale
return
loss
return
loss
...
@@ -1352,7 +1351,7 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1352,7 +1351,7 @@ class AlphaFoldLoss(nn.Module):
))
))
if
(
"backbone_affine_tensor"
not
in
batch
.
keys
()):
if
(
"backbone_affine_tensor"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
atom37_to_frames
(
**
batch
))
batch
.
update
(
feats
.
atom37_to_frames
(
eps
=
self
.
config
.
eps
,
**
batch
))
# TODO: Verify that this is correct
# TODO: Verify that this is correct
batch
[
"backbone_affine_tensor"
]
=
(
batch
[
"backbone_affine_tensor"
]
=
(
...
@@ -1363,16 +1362,19 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1363,16 +1362,19 @@ class AlphaFoldLoss(nn.Module):
)
)
if
(
"chi_angles_sin_cos"
not
in
batch
.
keys
()):
if
(
"chi_angles_sin_cos"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
atom37_to_torsion_angles
(
with
torch
.
no_grad
():
**
batch
,
batch
.
update
(
feats
.
atom37_to_torsion_angles
(
eps
=
self
.
config
.
eps
,
aatype
=
batch
[
"aatype"
],
))
all_atom_positions
=
batch
[
"all_atom_positions"
].
double
(),
all_atom_mask
=
batch
[
"all_atom_mask"
].
double
(),
# TODO: Verify that this is correct
eps
=
self
.
config
.
eps
,
batch
[
"chi_angles_sin_cos"
]
=
(
))
batch
[
"torsion_angles_sin_cos"
][...,
3
:,
:]
)
# TODO: Verify that this is correct
batch
[
"chi_mask"
]
=
batch
[
"torsion_angles_mask"
][...,
3
:]
batch
[
"chi_angles_sin_cos"
]
=
(
batch
[
"torsion_angles_sin_cos"
][...,
3
:,
:]
).
to
(
batch
[
"all_atom_mask"
].
dtype
)
batch
[
"chi_mask"
]
=
batch
[
"torsion_angles_mask"
][...,
3
:].
to
(
batch
[
"all_atom_mask"
].
dtype
)
loss_fns
=
{
loss_fns
=
{
"distogram"
:
"distogram"
:
...
...
openfold/utils/tensor_utils.py
View file @
893fe372
...
@@ -25,11 +25,11 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
...
@@ -25,11 +25,11 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
return
tensor
.
permute
(
first_inds
+
[
zero_index
+
i
for
i
in
inds
])
return
tensor
.
permute
(
first_inds
+
[
zero_index
+
i
for
i
in
inds
])
def
flatten_final_dims
(
t
ensor
:
torch
.
Tensor
,
no_dims
:
int
):
def
flatten_final_dims
(
t
:
torch
.
Tensor
,
no_dims
:
int
):
return
t
ensor
.
reshape
(
t
ensor
.
shape
[:
-
no_dims
]
+
(
-
1
,))
return
t
.
reshape
(
t
.
shape
[:
-
no_dims
]
+
(
-
1
,))
def
masked_mean
(
mask
,
value
,
dim
,
eps
=
1e-
10
):
def
masked_mean
(
mask
,
value
,
dim
,
eps
=
1e-
4
):
mask
=
mask
.
expand
(
*
value
.
shape
)
mask
=
mask
.
expand
(
*
value
.
shape
)
return
torch
.
sum
(
mask
*
value
,
dim
=
dim
)
/
(
eps
+
torch
.
sum
(
mask
,
dim
=
dim
))
return
torch
.
sum
(
mask
*
value
,
dim
=
dim
)
/
(
eps
+
torch
.
sum
(
mask
,
dim
=
dim
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment