OpenDAS / OpenFold / Commits / 893fe372

Commit 893fe372, authored Oct 05, 2021 by Gustaf Ahdritz (parent dd06b323)

    Get FP16 training working

Showing 10 changed files with 203 additions and 161 deletions (+203 -161).
    config.py                            +6    -6
    openfold/model/model.py              +5    -2
    openfold/model/primitives.py         +0    -2
    openfold/model/structure_module.py   +16   -107
    openfold/utils/affine_utils.py       +28   -12
    openfold/utils/deepspeed.py          +2    -5
    openfold/utils/feats.py              +123  -6
    openfold/utils/import_weights.py     +1    -1
    openfold/utils/loss.py               +19   -17
    openfold/utils/tensor_utils.py       +3    -3
config.py (+6 -6)

@@ -44,7 +44,7 @@ def model_config(name, train=False, low_prec=False):
         raise ValueError("Invalid model name")
 
     if(train):
-        c.globals.model.blocks_per_ckpt = 1
+        c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
     if(low_prec):

@@ -137,7 +137,7 @@ config = mlc.ConfigDict({
             },
             "inf": 1e9,
             "eps": eps,  # 1e-6,
-            "enabled": True,
+            "enabled": False,  # True,
             "embed_angles": True,
         },
         "extra_msa": {

@@ -239,7 +239,7 @@ config = mlc.ConfigDict({
             "max_bin": 21.6875,
             "no_bins": 64,
             "eps": eps,  # 1e-6,
-            "weight": 0.,  # 0.3,
+            "weight": 0.3,
         },
         "experimentally_resolved": {
             "eps": eps,  # 1e-8,

@@ -267,17 +267,17 @@ config = mlc.ConfigDict({
             "cutoff": 15.,
             "no_bins": 50,
             "eps": eps,  # 1e-10,
-            "weight": 0.,  # 0.01,
+            "weight": 0.01,
         },
         "masked_msa": {
             "eps": eps,  # 1e-8,
-            "weight": 0.,  # 2.0,
+            "weight": 2.0,
         },
         "supervised_chi": {
             "chi_weight": 0.5,
             "angle_norm_weight": 0.01,
             "eps": eps,  # 1e-6,
-            "weight": 0.,  # 1.0,
+            "weight": 1.0,
         },
         "violation": {
             "violation_tolerance_factor": 12.0,
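These hunks restore the auxiliary loss weights that had been zeroed out during debugging, disable angle embedding for now, and fix blocks_per_ckpt to live on c.globals rather than c.globals.model. A minimal usage sketch (assuming the module imports as `config` and that "model_1" is a valid preset name; neither is shown in this diff):

    from config import model_config

    # train=True sets per-block activation checkpointing and disables
    # chunked inference; low_prec=True takes the reduced-precision branch.
    c = model_config("model_1", train=True, low_prec=True)
    assert c.globals.blocks_per_ckpt == 1
    assert c.globals.chunk_size is None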
openfold/model/model.py (+5 -2)

@@ -389,6 +389,7 @@ class AlphaFold(nn.Module):
         m_1_prev, z_prev, x_prev = None, None, None
         is_grad_enabled = torch.is_grad_enabled()
+        self._disable_activation_checkpointing()
 
         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):

@@ -400,7 +401,9 @@ class AlphaFold(nn.Module):
             is_final_iter = (cycle_no == (self.config.no_cycles - 1))
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                 # Sidestep AMP bug discussed in pytorch issue #65766
-                if(is_final_iter and torch.is_autocast_enabled()):
-                    torch.clear_autocast_cache()
+                if(is_final_iter):
+                    self._enable_activation_checkpointing()
+                    if(torch.is_autocast_enabled()):
+                        torch.clear_autocast_cache()
 
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
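The recycling loop now runs all but the last iteration without gradients and with activation checkpointing disabled, re-enabling both only for the final pass, and clears autocast's cast cache beforehand (the AMP bug in pytorch/pytorch#65766). A minimal sketch of the control flow, assuming a generic `module(inputs, prev)` callable in place of OpenFold's `iteration`:

    import torch

    def recycle(module, inputs, no_cycles):
        prev = None
        is_grad_enabled = torch.is_grad_enabled()
        for cycle_no in range(no_cycles):
            is_final_iter = cycle_no == (no_cycles - 1)
            # Gradients (and, in the real model, activation checkpointing)
            # are only needed for the last recycling iteration.
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter and torch.is_autocast_enabled():
                    # Drop cached FP16 casts of FP32 weights so the final,
                    # grad-enabled pass doesn't reuse stale cast tensors.
                    torch.clear_autocast_cache()
                prev = module(inputs, prev)
        return prev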
openfold/model/primitives.py (+0 -2)

@@ -257,8 +257,6 @@ class Attention(nn.Module):
             a = a + b
 
         a = self.softmax(a)
-        #print(torch.any(torch.isnan(a)))
-
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
openfold/model/structure_module.py (+16 -107)

@@ -26,6 +26,10 @@ from openfold.np.residue_constants import (
     restype_atom14_rigid_group_positions,
 )
 from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.feats import (
+    frames_and_literature_positions_to_atom14_pos,
+    torsion_angles_to_frames,
+)
 from openfold.utils.tensor_utils import (
     dict_multimap,
     permute_final_dims,

@@ -305,7 +309,7 @@ class InvariantPointAttention(nn.Module):
         pt_att = pt_att ** 2
 
         # [*, N_res, N_res, H, P_q]
-        pt_att = torch.sum(pt_att, dim=-1)
+        pt_att = sum(torch.unbind(pt_att, dim=-1))
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )

@@ -358,7 +362,7 @@ class InvariantPointAttention(nn.Module):
         )
 
         # [*, N_res, H * P_v, 3]
-        o_pt = o_pt.view(*o_pt.shape[:-3], -1, 3)
+        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
 
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z)

@@ -409,14 +413,17 @@ class BackboneUpdate(nn.Module):
         quats, trans = params[..., :3], params[..., 3:]
 
         # [*]
+        #norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
         norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
 
-        # As many ones as there are dimensions in quats
-        ones = s.new_ones((1,) * len(quats.shape))
+        # [*, 3]
+        ones = (
+            s.new_ones((1,) * len(quats.shape)).expand(quats.shape[:-1] + (1,))
+        )
 
         # [*, 4]
-        quats = torch.cat((ones.expand(*quats.shape[:-1], 1), quats), dim=-1)
-        quats = quats / norm_denom.unsqueeze(-1)
+        quats = torch.cat([ones, quats], dim=-1)
+        quats = quats / norm_denom[..., None]
 
         # [*, 3, 3]
         rots = quat_to_rot(quats)

@@ -424,105 +431,6 @@ class BackboneUpdate(nn.Module):
         return T(rots, trans)
 
-
-def _torsion_angles_to_frames(t, alpha, f, rrgdf):
-    # [*, N, 8, 4, 4]
-    default_4x4 = rrgdf[f, ...]
-
-    # [*, N, 8] transformations, i.e.
-    #   One [*, N, 8, 3, 3] rotation matrix and
-    #   One [*, N, 8, 3]    translation matrix
-    default_t = T.from_4x4(default_4x4)
-
-    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
-    bb_rot[..., 1] = 1
-
-    # [*, N, 8, 2]
-    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
-
-    # [*, N, 8, 3, 3]
-    # Produces rotation matrices of the form:
-    # [
-    #   [1, 0  , 0  ],
-    #   [0, a_2,-a_1],
-    #   [0, a_1, a_2]
-    # ]
-    # This follows the original code rather than the supplement, which uses
-    # different indices.
-    all_rots = alpha.new_zeros(default_t.rots.shape)
-    all_rots[..., 0, 0] = 1
-    all_rots[..., 1, 1] = alpha[..., 1]
-    all_rots[..., 1, 2] = -alpha[..., 0]
-    all_rots[..., 2, 1:] = alpha
-
-    all_rots = T(all_rots, None)
-
-    all_frames = default_t.compose(all_rots)
-
-    chi2_frame_to_frame = all_frames[..., 5]
-    chi3_frame_to_frame = all_frames[..., 6]
-    chi4_frame_to_frame = all_frames[..., 7]
-
-    chi1_frame_to_bb = all_frames[..., 4]
-    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
-    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
-    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
-
-    all_frames_to_bb = T.concat([
-        all_frames[..., :5],
-        chi2_frame_to_bb.unsqueeze(-1),
-        chi3_frame_to_bb.unsqueeze(-1),
-        chi4_frame_to_bb.unsqueeze(-1),
-    ], dim=-1)
-
-    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
-
-    return all_frames_to_global
-
-
-def _frames_and_literature_positions_to_atom14_pos(
-    t, f, default_frames, group_idx, atom_mask, lit_positions,
-):
-    # [*, N, 14, 4, 4]
-    default_4x4 = default_frames[f, ...]
-
-    # [*, N, 14]
-    group_mask = group_idx[f, ...]
-
-    # [*, N, 14, 8]
-    group_mask = nn.functional.one_hot(
-        group_mask,
-        num_classes=default_frames.shape[-3],
-    )
-
-    # [*, N, 14, 8]
-    t_atoms_to_global = t[..., None, :] * group_mask
-
-    # [*, N, 14]
-    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
-        lambda x: torch.sum(x, dim=-1)
-    )
-
-    # [*, N, 14, 1]
-    atom_mask = atom_mask[f, ...].unsqueeze(-1)
-
-    # [*, N, 14, 3]
-    lit_positions = lit_positions[f, ...]
-    pred_positions = t_atoms_to_global.apply(lit_positions)
-    pred_positions = pred_positions * atom_mask
-
-    return pred_positions
-
-
 class StructureModuleTransitionLayer(nn.Module):
     def __init__(self, c):
         super(StructureModuleTransitionLayer, self).__init__()

@@ -664,6 +572,7 @@ class StructureModule(nn.Module):
             self.no_qk_points,
             self.no_v_points,
             inf=self.inf,
+            eps=self.epsilon,
         )
 
         self.ipa_dropout = nn.Dropout(self.dropout_rate)

@@ -791,7 +700,7 @@ class StructureModule(nn.Module):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
-        return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
+        return torsion_angles_to_frames(t, alpha, f, self.default_frames)
 
     def frames_and_literature_positions_to_atom14_pos(self,
         t,  # [*, N, 8]

@@ -799,7 +708,7 @@ class StructureModule(nn.Module):
     ):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(t.rots.dtype, t.rots.device)
-        return _frames_and_literature_positions_to_atom14_pos(
+        return frames_and_literature_positions_to_atom14_pos(
             t,
             f,
             self.default_frames,
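Beyond moving the two helpers into openfold/utils/feats.py, the remaining edits here are small precision and layout idioms: torch.sum(x, dim=-1) becomes sum(torch.unbind(x, dim=-1)), and .view() becomes .reshape(). An illustrative sketch of both on a toy half-precision tensor (not part of the commit):

    import torch

    x = torch.randn(2, 8, 4, dtype=torch.half)

    # Unbinding and summing with Python's sum() performs the same reduction
    # as torch.sum, one elementwise add at a time, which can behave better
    # under autocast in half precision.
    assert torch.allclose(torch.sum(x, dim=-1), sum(torch.unbind(x, dim=-1)))

    # .view() requires contiguous memory; the transposes inside attention
    # produce non-contiguous tensors, where .reshape() copies as needed.
    y = x.transpose(-1, -2)
    z = y.reshape(2, -1)  # y.view(2, -1) would raise a RuntimeError here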
openfold/utils/affine_utils.py (+28 -12)

@@ -188,17 +188,29 @@ class T:
     @staticmethod
     def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
-        e0 = origin - p_neg_x_axis
-        e1 = p_xy_plane - origin
-
-        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
-        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
-        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
-        e2 = torch.cross(e0, e1)
-
-        rots = torch.stack([e0, e1, e2], dim=-1)
-
-        return T(rots, origin)
+        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
+        origin = torch.unbind(origin, dim=-1)
+        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+
+        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
+        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+
+        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+        e0 = [c / denom for c in e0]
+        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
+        e1 = [c1 - c2 * dot for c1, c2 in zip(e1, e0)]
+        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+        e1 = [c / denom for c in e1]
+        e2 = [
+            e0[1] * e1[2] - e0[2] * e1[1],
+            e0[2] * e1[0] - e0[0] * e1[2],
+            e0[0] * e1[1] - e0[1] * e1[0],
+        ]
+
+        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
+        rots = rots.reshape(rots.shape[:-1] + (3, 3))
+
+        return T(rots, torch.stack(origin, dim=-1))
 
     @staticmethod
     def concat(ts, dim):

@@ -294,6 +306,9 @@ class T:
         return T(rots, translation)
 
+    def cuda(self):
+        return T(self.rots.cuda(), self.trans.cuda())
+
 _quat_elements = ['a', 'b', 'c', 'd']
 _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]

@@ -325,10 +340,11 @@ def quat_to_rot(
     # [*, 4, 4]
     quat = quat[..., None] * quat[..., None, :]
 
+    # [4, 4, 3, 3]
     mat = quat.new_tensor(_qtr_mat)
 
     # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
     quat = quat[..., None, None] * shaped_qtr_mat
 
     # [*, 3, 3]
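from_3_points now runs Gram-Schmidt over per-axis coordinate lists obtained with torch.unbind, rather than vector ops with keepdims reductions and torch.cross. A quick double-precision check (illustrative, not from the commit) that the two formulations build the same basis:

    import torch

    p, o, q = torch.randn(3, 5, 3, dtype=torch.double).unbind(0)

    # Old style: vector arithmetic with keepdim reductions.
    e0 = o - p
    e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdim=True) + 1e-8)
    e1 = (q - o) - e0 * torch.sum(e0 * (q - o), dim=-1, keepdim=True)
    e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdim=True) + 1e-8)

    # New style: lists of per-axis tensors reduced with Python's sum().
    o_u = torch.unbind(o, dim=-1)
    f0 = [c1 - c2 for c1, c2 in zip(o_u, torch.unbind(p, dim=-1))]
    d = torch.sqrt(sum(c * c for c in f0) + 1e-8)
    f0 = [c / d for c in f0]
    f1 = [c1 - c2 for c1, c2 in zip(torch.unbind(q, dim=-1), o_u)]
    dot = sum(c1 * c2 for c1, c2 in zip(f0, f1))
    f1 = [c1 - c2 * dot for c1, c2 in zip(f1, f0)]
    d = torch.sqrt(sum(c * c for c in f1) + 1e-8)
    f1 = [c / d for c in f1]

    assert torch.allclose(e0, torch.stack(f0, dim=-1))
    assert torch.allclose(e1, torch.stack(f1, dim=-1))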
openfold/utils/deepspeed.py (+2 -5)

@@ -70,11 +70,8 @@ def checkpoint_blocks(
         for s in range(0, len(blocks), blocks_per_ckpt):
             e = s + blocks_per_ckpt
-            #print(len(args))
-            #args = checkpoint(chunker(s, e), *args)
-            #for a in args:
-            args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
-            # print(a.requires_grad)
+            args = checkpoint(chunker(s, e), *args)
+            #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
             args = wrap(args)
 
         return args
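The commit drops deepspeed.checkpointing.checkpoint in favor of torch.utils.checkpoint.checkpoint, presumably because the torch-native version cooperates with torch.cuda.amp autocast state. A self-contained sketch of the surrounding pattern (assuming each block maps a tuple of tensors to a tuple of tensors; the real `chunker` and `wrap` helpers are analogous):

    import torch
    from torch.utils.checkpoint import checkpoint

    def checkpoint_blocks(blocks, args, blocks_per_ckpt):
        def chunker(s, e):
            def run(*a):
                for block in blocks[s:e]:
                    a = block(*a)
                return a
            return run

        for s in range(0, len(blocks), blocks_per_ckpt):
            # Activations inside each chunk are recomputed during backward
            # rather than stored, trading compute for memory.
            args = checkpoint(chunker(s, s + blocks_per_ckpt), *args)
            if not isinstance(args, tuple):
                args = (args,)  # stands in for the real wrap()
        return args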
openfold/utils/feats.py (+123 -6)

@@ -173,6 +173,11 @@ def atom37_to_torsion_angles(
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     """
+        Convert coordinates to torsion angles.
+
+        This function is extremely sensitive to floating point imprecisions
+        and should be run with double precision whenever possible.
+
         Args:
             aatype:
                 [*, N_res] residue indices

@@ -228,10 +233,10 @@ def atom37_to_torsion_angles(
     )
     phi_mask = (
         prev_all_atom_mask[..., 2] *
-        torch.prod(all_atom_mask[..., :3], dim=-1)
+        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
     )
     psi_mask = (
-        torch.prod(all_atom_mask[..., :3], dim=-1) *
+        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
         all_atom_mask[..., 4]
     )

@@ -256,7 +261,9 @@ def atom37_to_torsion_angles(
         dim=-1,
         no_batch_dims=len(atom_indices.shape[:-2])
     )
-    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
+    chi_angle_atoms_mask = torch.prod(
+        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+    )
     chis_mask = chis_mask * chi_angle_atoms_mask
 
     torsions_atom_pos = torch.cat(

@@ -281,6 +288,7 @@ def atom37_to_torsion_angles(
         torsions_atom_pos[..., 1, :],
         torsions_atom_pos[..., 2, :],
         torsions_atom_pos[..., 0, :],
+        eps=eps,
     )
 
     fourth_atom_rel_pos = torsion_frames.invert().apply(

@@ -290,15 +298,19 @@ def atom37_to_torsion_angles(
     torsion_angles_sin_cos = torch.stack(
         [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
     )
     denom = torch.sqrt(
         torch.sum(
-            torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
+            torch.square(torsion_angles_sin_cos),
+            dim=-1,
+            dtype=torsion_angles_sin_cos.dtype,
+            keepdims=True,
         )
         + eps
     )
     torsion_angles_sin_cos = torsion_angles_sin_cos / denom
 
-    torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
-        [1., 1., -1., 1., 1., 1., 1.],
-        device=aatype.device,
+    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+        [1., 1., -1., 1., 1., 1., 1.],
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
 
     chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(

@@ -327,6 +339,7 @@ def atom37_to_frames(
     aatype: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
+    eps: float,
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     batch_dims = len(aatype.shape[:-1])

@@ -387,6 +400,7 @@ def atom37_to_frames(
         p_neg_x_axis=base_atom_pos[..., 0, :],
         origin=base_atom_pos[..., 1, :],
         p_xy_plane=base_atom_pos[..., 2, :],
+        eps=eps,
     )
 
     group_exists = batched_gather(

@@ -638,3 +652,106 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
         "atom14_alt_gt_positions": atom14_alt_gt_positions,
         "atom14_alt_gt_exists": atom14_alt_gt_exists,
     }
+
+
+def torsion_angles_to_frames(
+    t: T,
+    alpha: torch.Tensor,
+    aatype: torch.Tensor,
+    rrgdf: torch.Tensor,
+):
+    # [*, N, 8, 4, 4]
+    default_4x4 = rrgdf[aatype, ...]
+
+    # [*, N, 8] transformations, i.e.
+    #   One [*, N, 8, 3, 3] rotation matrix and
+    #   One [*, N, 8, 3]    translation matrix
+    default_t = T.from_4x4(default_4x4)
+
+    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+    bb_rot[..., 1] = 1
+
+    # [*, N, 8, 2]
+    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
+
+    # [*, N, 8, 3, 3]
+    # Produces rotation matrices of the form:
+    # [
+    #   [1, 0  , 0  ],
+    #   [0, a_2,-a_1],
+    #   [0, a_1, a_2]
+    # ]
+    # This follows the original code rather than the supplement, which uses
+    # different indices.
+    all_rots = alpha.new_zeros(default_t.rots.shape)
+    all_rots[..., 0, 0] = 1
+    all_rots[..., 1, 1] = alpha[..., 1]
+    all_rots[..., 1, 2] = -alpha[..., 0]
+    all_rots[..., 2, 1:] = alpha
+
+    all_rots = T(all_rots, None)
+
+    all_frames = default_t.compose(all_rots)
+
+    chi2_frame_to_frame = all_frames[..., 5]
+    chi3_frame_to_frame = all_frames[..., 6]
+    chi4_frame_to_frame = all_frames[..., 7]
+
+    chi1_frame_to_bb = all_frames[..., 4]
+    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+    all_frames_to_bb = T.concat([
+        all_frames[..., :5],
+        chi2_frame_to_bb.unsqueeze(-1),
+        chi3_frame_to_bb.unsqueeze(-1),
+        chi4_frame_to_bb.unsqueeze(-1),
+    ], dim=-1)
+
+    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
+
+    return all_frames_to_global
+
+
+def frames_and_literature_positions_to_atom14_pos(
+    t: T,
+    aatype: torch.Tensor,
+    default_frames,
+    group_idx,
+    atom_mask,
+    lit_positions,
+):
+    # [*, N, 14, 4, 4]
+    default_4x4 = default_frames[aatype, ...]
+
+    # [*, N, 14]
+    group_mask = group_idx[aatype, ...]
+
+    # [*, N, 14, 8]
+    group_mask = nn.functional.one_hot(
+        group_mask,
+        num_classes=default_frames.shape[-3],
+    )
+
+    # [*, N, 14, 8]
+    t_atoms_to_global = t[..., None, :] * group_mask
+
+    # [*, N, 14]
+    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+        lambda x: torch.sum(x, dim=-1)
+    )
+
+    # [*, N, 14, 1]
+    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+
+    # [*, N, 14, 3]
+    lit_positions = lit_positions[aatype, ...]
+    pred_positions = t_atoms_to_global.apply(lit_positions)
+    pred_positions = pred_positions * atom_mask
+
+    return pred_positions
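The two frame helpers now live here (imported back into the structure module), and several mask reductions pin their output dtype explicitly. An illustration of why the dtype= and new_tensor changes matter, independent of OpenFold:

    import torch

    # Reductions over boolean masks promote to int64 unless told otherwise.
    bool_mask = torch.tensor([[True, True, False]])
    print(torch.prod(bool_mask, dim=-1).dtype)                    # torch.int64
    print(torch.prod(bool_mask, dim=-1, dtype=torch.half).dtype)  # torch.float16

    # torch.tensor([...]) defaults to float32; new_tensor inherits the
    # reference tensor's dtype and device, so constants stay in FP16.
    ref = torch.zeros(3, dtype=torch.half)
    print(ref.new_tensor([1., 1., -1.]).dtype)                    # torch.float16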
openfold/utils/import_weights.py (+1 -1)

@@ -417,7 +417,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     if("_ptm" in version):
         translations["predicted_aligned_error_head"] = {
             "logits":
-                LinearParams(model.aux_heads.tm_score.linear)
+                LinearParams(model.aux_heads.tm.linear)
         }
 
     # Flatten keys and insert missing key prefixes
openfold/utils/loss.py (+19 -17)

@@ -273,7 +273,6 @@ def supervised_chi_loss(
     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
     true_chi_shifted = shifted_mask * true_chi
-
     sq_chi_error = torch.sum(
         (true_chi - pred_angles) ** 2, dim=-1
     )

@@ -498,11 +497,11 @@ def tm_loss(
     )
 
     loss = torch.sum(errors * square_mask, dim=-1)
-    scale = 0.1  # hack to help FP16 training along
+    scale = 0.5  # hack to help FP16 training along
     denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale
 
     loss = loss * (
         (resolution >= min_resolution) &

@@ -744,7 +743,7 @@ def between_residue_clash_loss(
     # Backbone C--N bond between subsequent residues is no clash.
     c_one_hot = torch.nn.functional.one_hot(
-        residue_index.new_tensor(2.), num_classes=14
+        residue_index.new_tensor(2), num_classes=14
     )
     c_one_hot = c_one_hot.reshape(
         *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape

@@ -1319,11 +1318,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     # )
     loss = errors * bert_mask
     loss = torch.sum(loss, dim=-1)
-    scale = 0.1
+    scale = 0.5
     denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale
 
     return loss

@@ -1352,7 +1351,7 @@ class AlphaFoldLoss(nn.Module):
         ))
 
         if("backbone_affine_tensor" not in batch.keys()):
-            batch.update(feats.atom37_to_frames(**batch))
+            batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
 
             # TODO: Verify that this is correct
             batch["backbone_affine_tensor"] = (

@@ -1363,16 +1362,19 @@ class AlphaFoldLoss(nn.Module):
         )
 
         if("chi_angles_sin_cos" not in batch.keys()):
-            batch.update(feats.atom37_to_torsion_angles(
-                **batch,
-                eps=self.config.eps,
-            ))
+            with torch.no_grad():
+                batch.update(feats.atom37_to_torsion_angles(
+                    aatype=batch["aatype"],
+                    all_atom_positions=batch["all_atom_positions"].double(),
+                    all_atom_mask=batch["all_atom_mask"].double(),
+                    eps=self.config.eps,
+                ))
 
             # TODO: Verify that this is correct
             batch["chi_angles_sin_cos"] = (
                 batch["torsion_angles_sin_cos"][..., 3:, :]
-            )
-            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
+            ).to(batch["all_atom_mask"].dtype)
+            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(batch["all_atom_mask"].dtype)
 
         loss_fns = {
             "distogram":
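The scale trick divides each row by a scaled mask count before the final reduction so the intermediate quotients stay well inside FP16 range, then multiplies the scale back out at the end. Since the denominator already contains scale, the old `loss / scale` applied the factor twice; `loss * scale` cancels it exactly. A worked half-precision example (illustrative values, not from the commit):

    import torch

    errors = torch.full((128, 512), 10.0, dtype=torch.half)
    mask = torch.ones_like(errors)

    # A direct masked sum overflows FP16's ~65504 maximum:
    print(torch.sum(errors * mask))        # tensor(inf, dtype=torch.float16)

    scale = 0.5
    loss = torch.sum(errors * mask, dim=-1)           # per-row sums stay finite
    denom = 1e-8 + torch.sum(scale * mask, dim=(-1, -2))
    loss = torch.sum(loss / denom, dim=-1) * scale    # scale cancels exactly
    print(loss)                            # tensor(10., dtype=torch.float16)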
openfold/utils/tensor_utils.py (+3 -3)

@@ -25,11 +25,11 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
     return tensor.permute(first_inds + [zero_index + i for i in inds])
 
-def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
-    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+    return t.reshape(t.shape[:-no_dims] + (-1,))
 
-def masked_mean(mask, value, dim, eps=1e-10):
+def masked_mean(mask, value, dim, eps=1e-4):
     mask = mask.expand(*value.shape)
     return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
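The eps bump in masked_mean is an FP16 range issue: the smallest normal half-precision value is about 6.1e-5, so the old guard vanished when cast. A sketch of what happens to each guard value in half precision:

    import torch

    print(torch.tensor(1e-10, dtype=torch.half))  # tensor(0., dtype=torch.float16)
    print(torch.tensor(1e-4, dtype=torch.half))   # tensor(1.0014e-04, dtype=torch.float16)

With 1e-10 rounding to zero, an all-zero mask would have divided by zero; 1e-4 survives the cast and keeps the mean finite.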