OpenDAS / OpenFold

Commit 893fe372, authored Oct 05, 2021 by Gustaf Ahdritz
Parent: dd06b323

Get FP16 training working

Showing 10 changed files with 203 additions and 161 deletions:

    config.py                            +6    -6
    openfold/model/model.py              +5    -2
    openfold/model/primitives.py         +0    -2
    openfold/model/structure_module.py   +16   -107
    openfold/utils/affine_utils.py       +28   -12
    openfold/utils/deepspeed.py          +2    -5
    openfold/utils/feats.py              +123  -6
    openfold/utils/import_weights.py     +1    -1
    openfold/utils/loss.py               +19   -17
    openfold/utils/tensor_utils.py       +3    -3
config.py  (+6 -6)

@@ -44,7 +44,7 @@ def model_config(name, train=False, low_prec=False):
         raise ValueError("Invalid model name")
     if(train):
-        c.globals.model.blocks_per_ckpt = 1
+        c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
     if(low_prec):

@@ -137,7 +137,7 @@ config = mlc.ConfigDict({
             },
             "inf": 1e9,
             "eps": eps,  # 1e-6,
-            "enabled": True,
+            "enabled": False,  # True,
             "embed_angles": True,
         },
         "extra_msa": {

@@ -239,7 +239,7 @@ config = mlc.ConfigDict({
             "max_bin": 21.6875,
             "no_bins": 64,
             "eps": eps,  # 1e-6,
-            "weight": 0.,  # 0.3,
+            "weight": 0.3,
         },
         "experimentally_resolved": {
             "eps": eps,  # 1e-8,

@@ -267,17 +267,17 @@ config = mlc.ConfigDict({
             "cutoff": 15.,
             "no_bins": 50,
             "eps": eps,  # 1e-10,
-            "weight": 0.,  # 0.01,
+            "weight": 0.01,
         },
         "masked_msa": {
             "eps": eps,  # 1e-8,
-            "weight": 0.,  # 2.0,
+            "weight": 2.0,
         },
         "supervised_chi": {
             "chi_weight": 0.5,
             "angle_norm_weight": 0.01,
             "eps": eps,  # 1e-6,
-            "weight": 0.,  # 1.0,
+            "weight": 1.0,
         },
         "violation": {
             "violation_tolerance_factor": 12.0,
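The hunks above restore the zeroed auxiliary loss weights to their commented-out values (0.3, 0.01, 2.0, 1.0), disable one model block ("enabled": False), and replace hard-coded epsilons (the trailing #1e-6 / #1e-8 / #1e-10 comments) with a shared `eps` variable that can be adjusted per precision mode. The `if(low_prec):` branch is cut off by the diff view; as a hedged sketch, a low-precision override of this shape would typically loosen the shared sentinels (the attribute names below are assumptions, not the actual branch body):

    # Hypothetical illustration of a low_prec override; not the real code.
    if low_prec:
        c.globals.eps = 1e-4   # 1e-8 and smaller underflow to 0 in float16
        c.globals.inf = 1e4    # float16 overflows above 65504, so 1e9 is unusable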
openfold/model/model.py  (+5 -2)

@@ -389,6 +389,7 @@ class AlphaFold(nn.Module):
         m_1_prev, z_prev, x_prev = None, None, None
+        is_grad_enabled = torch.is_grad_enabled()
         self._disable_activation_checkpointing()

         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):

@@ -400,8 +401,10 @@
             is_final_iter = (cycle_no == (self.config.no_cycles - 1))
+            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+                # Sidestep AMP bug discussed in pytorch issue #65766
+                if(is_final_iter and torch.is_autocast_enabled()):
+                    torch.clear_autocast_cache()
                 if(is_final_iter):
                     self._enable_activation_checkpointing()
-                    if(torch.is_autocast_enabled()):
-                        torch.clear_autocast_cache()

                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats, m_1_prev, z_prev, x_prev,
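Taken together, these two hunks make recycling AMP-safe: the ambient grad mode is captured once, every cycle but the last runs without gradients, and the autocast cast cache is flushed before the final, grad-enabled iteration so that fp16 weight casts created under no_grad are not reused (the bug tracked as pytorch issue #65766). A minimal free-standing sketch of the pattern; the function and argument names here are illustrative, not OpenFold's exact API:

    import torch

    def recycle(iteration_fn, feats, no_cycles):
        is_grad_enabled = torch.is_grad_enabled()  # capture once, outside the loop
        m_1_prev = z_prev = x_prev = None
        for cycle_no in range(no_cycles):
            is_final_iter = cycle_no == (no_cycles - 1)
            # Only the final recycling iteration contributes gradients
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter and torch.is_autocast_enabled():
                    # Weight casts cached while grad was disabled carry no
                    # autograd history; drop them before the backward pass
                    torch.clear_autocast_cache()
                outputs, m_1_prev, z_prev, x_prev = iteration_fn(
                    feats, m_1_prev, z_prev, x_prev
                )
        return outputs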
openfold/model/primitives.py  (+0 -2)

@@ -257,8 +257,6 @@ class Attention(nn.Module):
         a = a + b

         a = self.softmax(a)
-        #print(torch.any(torch.isnan(a)))
-
         # [*, H, Q, C_hidden]
         o = torch.matmul(a,
openfold/model/structure_module.py  (+16 -107)

@@ -26,6 +26,10 @@ from openfold.np.residue_constants import (
     restype_atom14_rigid_group_positions,
 )
 from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.feats import (
+    frames_and_literature_positions_to_atom14_pos,
+    torsion_angles_to_frames,
+)
 from openfold.utils.tensor_utils import (
     dict_multimap,
     permute_final_dims,

@@ -305,7 +309,7 @@ class InvariantPointAttention(nn.Module):
         pt_att = pt_att ** 2

         # [*, N_res, N_res, H, P_q]
-        pt_att = torch.sum(pt_att, dim=-1)
+        pt_att = sum(torch.unbind(pt_att, dim=-1))
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )

@@ -358,7 +362,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_v, 3]
-        o_pt = o_pt.view(*o_pt.shape[:-3], -1, 3)
+        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z)

@@ -409,14 +413,17 @@ class BackboneUpdate(nn.Module):
         quats, trans = params[..., :3], params[..., 3:]

         # [*]
+        #norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
         norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)

-        # As many ones as there are dimensions in quats
-        ones = s.new_ones((1,) * len(quats.shape))
+        # [*, 3]
+        ones = (
+            s.new_ones((1,) * len(quats.shape)).expand(quats.shape[:-1] + (1,))
+        )

         # [*, 4]
-        quats = torch.cat((ones.expand(*quats.shape[:-1], 1), quats), dim=-1)
-        quats = quats / norm_denom.unsqueeze(-1)
+        quats = torch.cat([ones, quats], dim=-1)
+        quats = quats / norm_denom[..., None]

         # [*, 3, 3]
         rots = quat_to_rot(quats)

@@ -424,105 +431,6 @@
         return T(rots, trans)

-def _torsion_angles_to_frames(t, alpha, f, rrgdf):
-    [body removed; it reappears below, essentially verbatim with f renamed to
-     aatype, as torsion_angles_to_frames in openfold/utils/feats.py]
-
-def _frames_and_literature_positions_to_atom14_pos(
-    t, f, default_frames, group_idx, atom_mask, lit_positions,
-):
-    [body likewise moved to openfold/utils/feats.py as
-     frames_and_literature_positions_to_atom14_pos]

 class StructureModuleTransitionLayer(nn.Module):
     def __init__(self, c):
         super(StructureModuleTransitionLayer, self).__init__()

@@ -664,6 +572,7 @@ class StructureModule(nn.Module):
             self.no_qk_points,
             self.no_v_points,
+            inf=self.inf,
             eps=self.epsilon,
         )

         self.ipa_dropout = nn.Dropout(self.dropout_rate)

@@ -791,7 +700,7 @@ class StructureModule(nn.Module):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
-        return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
+        return torsion_angles_to_frames(t, alpha, f, self.default_frames)

     def frames_and_literature_positions_to_atom14_pos(
         self, t,  # [*, N, 8]

@@ -799,7 +708,7 @@
     ):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(t.rots.dtype, t.rots.device)
-        return _frames_and_literature_positions_to_atom14_pos(
+        return frames_and_literature_positions_to_atom14_pos(
             t,
             f,
             self.default_frames,
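A pattern worth flagging in the InvariantPointAttention hunk: `torch.sum(pt_att, dim=-1)` becomes `sum(torch.unbind(pt_att, dim=-1))`. The commit message doesn't explain the swap; the two forms are the same reduction up to summation order, so the likely motivation (an assumption here) is friendlier dtype and memory behavior under fp16 autocast. A quick equivalence check:

    import torch

    x = torch.randn(2, 8, 8, 4, 3, dtype=torch.float16)
    a = torch.sum(x, dim=-1)
    b = sum(torch.unbind(x, dim=-1))  # Python-level sum of 3 fp16 tensors
    assert torch.allclose(a.float(), b.float(), atol=1e-2)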
openfold/utils/affine_utils.py  (+28 -12)

@@ -188,17 +188,29 @@ class T:
     @staticmethod
     def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
-        e0 = origin - p_neg_x_axis
-        e1 = p_xy_plane - origin
-
-        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
-        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
-        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
-        e2 = torch.cross(e0, e1)
-
-        rots = torch.stack([e0, e1, e2], dim=-1)
-
-        return T(rots, origin)
+        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
+        origin = torch.unbind(origin, dim=-1)
+        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+
+        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
+        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+
+        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+        e0 = [c / denom for c in e0]
+        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
+        e1 = [c1 - c2 * dot for c1, c2 in zip(e1, e0)]
+        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+        e1 = [c / denom for c in e1]
+        e2 = [
+            e0[1] * e1[2] - e0[2] * e1[1],
+            e0[2] * e1[0] - e0[0] * e1[2],
+            e0[0] * e1[1] - e0[1] * e1[0],
+        ]
+
+        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
+        rots = rots.reshape(rots.shape[:-1] + (3, 3))
+
+        return T(rots, torch.stack(origin, dim=-1))

@@ -294,6 +306,9 @@ class T:
         return T(rots, translation)

+    def cuda(self):
+        return T(self.rots.cuda(), self.trans.cuda())
+
 _quat_elements = ['a', 'b', 'c', 'd']
 _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]

@@ -325,10 +340,11 @@ def quat_to_rot(
     # [*, 4, 4]
     quat = quat[..., None] * quat[..., None, :]

     # [4, 4, 3, 3]
     mat = quat.new_tensor(_qtr_mat)

     # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
     quat = quat[..., None, None] * shaped_qtr_mat

     # [*, 3, 3]
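The rewritten `T.from_3_points` computes the same Gram-Schmidt frame construction as before, but on per-coordinate tensors from `torch.unbind`, spelling out the cross product and the last-dim reductions component by component instead of using `torch.cross` and `torch.sum(..., dim=-1)`. A hedged usage sketch (assumes `openfold.utils.affine_utils` is importable; the toy inputs are illustrative):

    import torch
    from openfold.utils.affine_utils import T

    # Backbone frames from N, CA, C coordinates for a toy batch of 8 residues
    n, ca, c = (torch.randn(8, 3) for _ in range(3))
    frames = T.from_3_points(p_neg_x_axis=n, origin=ca, p_xy_plane=c, eps=1e-8)
    # frames.rots: [8, 3, 3] rotation matrices; frames.trans: [8, 3] (== ca)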
openfold/utils/deepspeed.py  (+2 -5)

@@ -70,11 +70,8 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
-        #print(len(args))
-        #for a in args:
-        #    print(a.requires_grad)
-        args = checkpoint(chunker(s, e), *args)
-        #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
+        #args = checkpoint(chunker(s, e), *args)
+        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)

     return args
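The change above routes gradient checkpointing through DeepSpeed's implementation instead of the stock PyTorch one, with the old call left commented for reference, and drops some debug prints. For context, a minimal sketch of the chunked-checkpointing idea (a standalone approximation, assuming each block maps an argument tuple to an argument tuple; this is not the full checkpoint_blocks):

    from torch.utils.checkpoint import checkpoint

    def checkpoint_blocks_sketch(blocks, args, blocks_per_ckpt):
        def chunker(s, e):
            def exec_chunk(*a):
                for block in blocks[s:e]:
                    a = block(*a)
                return a
            return exec_chunk
        # Recompute blocks_per_ckpt blocks at a time in the backward pass,
        # trading compute for activation memory
        for s in range(0, len(blocks), blocks_per_ckpt):
            args = checkpoint(chunker(s, s + blocks_per_ckpt), *args)
        return args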
openfold/utils/feats.py  (+123 -6)

@@ -173,6 +173,11 @@ def atom37_to_torsion_angles(
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     """
+        Convert coordinates to torsion angles.
+
+        This function is extremely sensitive to floating point imprecisions
+        and should be run with double precision whenever possible.
+
         Args:
             aatype:
                 [*, N_res] residue indices

@@ -228,10 +233,10 @@ def atom37_to_torsion_angles(
     )
     phi_mask = (
         prev_all_atom_mask[..., 2]
-        * torch.prod(all_atom_mask[..., :3], dim=-1)
+        * torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
     )
     psi_mask = (
-        torch.prod(all_atom_mask[..., :3], dim=-1)
+        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
         * all_atom_mask[..., 4]
     )

@@ -256,7 +261,9 @@ def atom37_to_torsion_angles(
         dim=-1,
         no_batch_dims=len(atom_indices.shape[:-2])
     )
-    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
+    chi_angle_atoms_mask = torch.prod(
+        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+    )
     chis_mask = chis_mask * chi_angle_atoms_mask

     torsions_atom_pos = torch.cat(

@@ -281,6 +288,7 @@ def atom37_to_torsion_angles(
         torsions_atom_pos[..., 1, :],
         torsions_atom_pos[..., 2, :],
         torsions_atom_pos[..., 0, :],
+        eps=eps,
     )

     fourth_atom_rel_pos = torsion_frames.invert().apply(

@@ -290,15 +298,19 @@ def atom37_to_torsion_angles(
     torsion_angles_sin_cos = torch.stack(
         [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
     )
     denom = torch.sqrt(
         torch.sum(
-            torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
+            torch.square(torsion_angles_sin_cos),
+            dim=-1,
+            dtype=torsion_angles_sin_cos.dtype,
+            keepdims=True,
         )
         + eps
     )
     torsion_angles_sin_cos = torsion_angles_sin_cos / denom

-    torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
-        [1., 1., -1., 1., 1., 1., 1.],
-        device=aatype.device,
+    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+        [1., 1., -1., 1., 1., 1., 1.],
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

     chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(

@@ -327,6 +339,7 @@ def atom37_to_frames(
     aatype: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
+    eps: float,
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     batch_dims = len(aatype.shape[:-1])

@@ -387,6 +400,7 @@ def atom37_to_frames(
         p_neg_x_axis=base_atom_pos[..., 0, :],
         origin=base_atom_pos[..., 1, :],
         p_xy_plane=base_atom_pos[..., 2, :],
+        eps=eps,
     )

     group_exists = batched_gather(

@@ -638,3 +652,106 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
         "atom14_alt_gt_positions": atom14_alt_gt_positions,
         "atom14_alt_gt_exists": atom14_alt_gt_exists,
     }
+
+def torsion_angles_to_frames(
+    t: T,
+    alpha: torch.Tensor,
+    aatype: torch.Tensor,
+    rrgdf: torch.Tensor,
+):
+    # [*, N, 8, 4, 4]
+    default_4x4 = rrgdf[aatype, ...]
+
+    # [*, N, 8] transformations, i.e.
+    #   One [*, N, 8, 3, 3] rotation matrix and
+    #   One [*, N, 8, 3]    translation matrix
+    default_t = T.from_4x4(default_4x4)
+
+    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+    bb_rot[..., 1] = 1
+
+    # [*, N, 8, 2]
+    alpha = torch.cat(
+        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
+    )
+
+    # [*, N, 8, 3, 3]
+    # Produces rotation matrices of the form:
+    # [
+    #   [1, 0  , 0  ],
+    #   [0, a_2,-a_1],
+    #   [0, a_1, a_2]
+    # ]
+    # This follows the original code rather than the supplement, which uses
+    # different indices.
+    all_rots = alpha.new_zeros(default_t.rots.shape)
+    all_rots[..., 0, 0] = 1
+    all_rots[..., 1, 1] = alpha[..., 1]
+    all_rots[..., 1, 2] = -alpha[..., 0]
+    all_rots[..., 2, 1:] = alpha
+
+    all_rots = T(all_rots, None)
+
+    all_frames = default_t.compose(all_rots)
+
+    chi2_frame_to_frame = all_frames[..., 5]
+    chi3_frame_to_frame = all_frames[..., 6]
+    chi4_frame_to_frame = all_frames[..., 7]
+
+    chi1_frame_to_bb = all_frames[..., 4]
+    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+    all_frames_to_bb = T.concat(
+        [
+            all_frames[..., :5],
+            chi2_frame_to_bb.unsqueeze(-1),
+            chi3_frame_to_bb.unsqueeze(-1),
+            chi4_frame_to_bb.unsqueeze(-1),
+        ],
+        dim=-1,
+    )
+
+    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
+
+    return all_frames_to_global
+
+def frames_and_literature_positions_to_atom14_pos(
+    t: T,
+    aatype: torch.Tensor,
+    default_frames,
+    group_idx,
+    atom_mask,
+    lit_positions,
+):
+    # [*, N, 14, 4, 4]
+    default_4x4 = default_frames[aatype, ...]
+
+    # [*, N, 14]
+    group_mask = group_idx[aatype, ...]
+
+    # [*, N, 14, 8]
+    group_mask = nn.functional.one_hot(
+        group_mask,
+        num_classes=default_frames.shape[-3],
+    )
+
+    # [*, N, 14, 8]
+    t_atoms_to_global = t[..., None, :] * group_mask
+
+    # [*, N, 14]
+    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+        lambda x: torch.sum(x, dim=-1)
+    )
+
+    # [*, N, 14, 1]
+    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+
+    # [*, N, 14, 3]
+    lit_positions = lit_positions[aatype, ...]
+    pred_positions = t_atoms_to_global.apply(lit_positions)
+    pred_positions = pred_positions * atom_mask
+
+    return pred_positions
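Two FP16 themes run through this file. First, the new docstring warns that atom37_to_torsion_angles is extremely sensitive to floating point imprecision, which is why the loss module (below) now calls it in float64. Second, the repeated `dtype=` additions pin reduction outputs to the input dtype: under autocast, reductions such as torch.prod run in (and return) float32, so an fp16 mask silently comes back widened. A hedged illustration (requires a CUDA device for the autocast behavior shown):

    import torch

    mask = torch.ones(8, 3, device="cuda", dtype=torch.float16)
    with torch.autocast("cuda"):
        a = torch.prod(mask, dim=-1)                    # autocast widens to float32
        b = torch.prod(mask, dim=-1, dtype=mask.dtype)  # pinned to float16
    print(a.dtype, b.dtype)  # torch.float32 torch.float16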
openfold/utils/import_weights.py  (+1 -1)

@@ -417,7 +417,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     if("_ptm" in version):
         translations["predicted_aligned_error_head"] = {
-            "logits": LinearParams(model.aux_heads.tm_score.linear)
+            "logits": LinearParams(model.aux_heads.tm.linear)
         }

     # Flatten keys and insert missing key prefixes
openfold/utils/loss.py  (+19 -17)

@@ -273,7 +273,6 @@ def supervised_chi_loss(
     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
     true_chi_shifted = shifted_mask * true_chi
-
     sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)

@@ -498,11 +497,11 @@ def tm_loss(
     )
     loss = torch.sum(errors * square_mask, dim=-1)
-    scale = 0.1  # hack to help FP16 training along
+    scale = 0.5  # hack to help FP16 training along
     denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale

     loss = loss * ((resolution >= min_resolution) &

@@ -744,7 +743,7 @@ def between_residue_clash_loss(
     # Backbone C--N bond between subsequent residues is no clash.
     c_one_hot = torch.nn.functional.one_hot(
-        residue_index.new_tensor(2.), num_classes=14
+        residue_index.new_tensor(2), num_classes=14
     )
     c_one_hot = c_one_hot.reshape(
         *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape

@@ -1319,11 +1318,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     # )
     loss = errors * bert_mask
     loss = torch.sum(loss, dim=-1)
-    scale = 0.1
+    scale = 0.5
     denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale

     return loss

@@ -1352,7 +1351,7 @@ class AlphaFoldLoss(nn.Module):
         ))
         if("backbone_affine_tensor" not in batch.keys()):
-            batch.update(feats.atom37_to_frames(**batch))
+            batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
             # TODO: Verify that this is correct
             batch["backbone_affine_tensor"] = (

@@ -1363,16 +1362,19 @@ class AlphaFoldLoss(nn.Module):
         if("chi_angles_sin_cos" not in batch.keys()):
-            batch.update(feats.atom37_to_torsion_angles(
-                **batch, eps=self.config.eps,
-            ))
-            # TODO: Verify that this is correct
-            batch["chi_angles_sin_cos"] = (
-                batch["torsion_angles_sin_cos"][..., 3:, :]
-            )
-            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
+            with torch.no_grad():
+                batch.update(feats.atom37_to_torsion_angles(
+                    aatype=batch["aatype"],
+                    all_atom_positions=batch["all_atom_positions"].double(),
+                    all_atom_mask=batch["all_atom_mask"].double(),
+                    eps=self.config.eps,
+                ))
+            # TODO: Verify that this is correct
+            batch["chi_angles_sin_cos"] = (
+                batch["torsion_angles_sin_cos"][..., 3:, :]
+            ).to(batch["all_atom_mask"].dtype)
+            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(
+                batch["all_atom_mask"].dtype
+            )

         loss_fns = {
             "distogram":
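A note on the `scale` hack in tm_loss and masked_msa_loss: with s = scale and eps ignored, the computation is

    loss = [sum(errors * mask) / (s * sum(mask))] * s = sum(errors * mask) / sum(mask)

so the factor cancels exactly; its only job is to keep the fp16 intermediate quotient away from overflow and underflow, and the value change (0.1 to 0.5) just repositions that intermediate. The switch from `loss / scale` to `loss * scale` also appears to fix a bug: with s already in the denominator, dividing by s again left the loss inflated by a factor of 1/s^2. The final hunk follows the new atom37_to_torsion_angles docstring by computing ground-truth torsion angles in float64 under no_grad, then casting the results back to the working dtype.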
openfold/utils/tensor_utils.py  (+3 -3)

@@ -25,11 +25,11 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
     return tensor.permute(first_inds + [zero_index + i for i in inds])

-def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
-    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+    return t.reshape(t.shape[:-no_dims] + (-1,))

-def masked_mean(mask, value, dim, eps=1e-10):
+def masked_mean(mask, value, dim, eps=1e-4):
     mask = mask.expand(*value.shape)
     return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
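Raising masked_mean's default eps from 1e-10 to 1e-4 is another fp16 fix: 1e-10 is far below float16's smallest positive subnormal (about 6e-8), so in half precision the old guard rounded away to nothing and an all-zero mask could still divide by zero. A two-line check:

    import torch

    print(torch.tensor(1e-10, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
    print(torch.tensor(1e-4, dtype=torch.float16))   # representable, ~1e-4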