Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
78567a86
Commit
78567a86
authored
Apr 27, 2022
by
Gustaf Ahdritz
Browse files
Fix validation metric bugs
parent
e1796607
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
51 deletions
+104
-51
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+16
-7
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+1
-0
openfold/utils/loss.py
openfold/utils/loss.py
+6
-33
openfold/utils/superimposition.py
openfold/utils/superimposition.py
+29
-5
openfold/utils/validation_metrics.py
openfold/utils/validation_metrics.py
+32
-1
train_openfold.py
train_openfold.py
+20
-5
No files found.
openfold/data/data_pipeline.py
View file @
78567a86
...
...
@@ -163,8 +163,8 @@ def make_protein_features(
def
make_pdb_features
(
protein_object
:
protein
.
Protein
,
description
:
str
,
confidence_threshold
:
float
=
0.5
,
is_distillation
:
bool
=
True
,
confidence_threshold
:
float
=
50.
,
)
->
FeatureDict
:
pdb_feats
=
make_protein_features
(
protein_object
,
description
,
_is_distillation
=
True
...
...
@@ -173,9 +173,7 @@ def make_pdb_features(
if
(
is_distillation
):
high_confidence
=
protein_object
.
b_factors
>
confidence_threshold
high_confidence
=
np
.
any
(
high_confidence
,
axis
=-
1
)
for
i
,
confident
in
enumerate
(
high_confidence
):
if
(
not
confident
):
pdb_feats
[
"all_atom_mask"
][
i
]
=
0
pdb_feats
[
"all_atom_mask"
]
*=
high_confidence
[...,
None
]
return
pdb_feats
...
...
@@ -620,13 +618,24 @@ class DataPipeline:
alignment_dir
:
str
,
is_distillation
:
bool
=
True
,
chain_id
:
Optional
[
str
]
=
None
,
_structure_index
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a protein in a PDB file.
"""
with
open
(
pdb_path
,
'r'
)
as
f
:
pdb_str
=
f
.
read
()
if
(
_structure_index
is
not
None
):
db_dir
=
os
.
path
.
dirname
(
pdb_path
)
db
=
_structure_index
[
"db"
]
db_path
=
os
.
path
.
join
(
db_dir
,
db
)
fp
=
open
(
db_path
,
"rb"
)
_
,
offset
,
length
=
_structure_index
[
"files"
][
0
]
fp
.
seek
(
offset
)
pdb_str
=
fp
.
read
(
length
).
decode
(
"utf-8"
)
fp
.
close
()
else
:
with
open
(
pdb_path
,
'r'
)
as
f
:
pdb_str
=
f
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_str
,
chain_id
)
input_sequence
=
_aatype_to_str_sequence
(
protein_object
.
aatype
)
...
...
@@ -634,7 +643,7 @@ class DataPipeline:
pdb_feats
=
make_pdb_features
(
protein_object
,
description
,
is_distillation
is_distillation
=
is_distillation
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
...
...
openfold/data/data_transforms.py
View file @
78567a86
...
...
@@ -463,6 +463,7 @@ def make_masked_msa(protein, config, replace_fraction):
1.0
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
)
assert
mask_prob
>=
0.0
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
)
...
...
openfold/utils/loss.py
View file @
78567a86
...
...
@@ -334,10 +334,12 @@ def supervised_chi_loss(
(
true_chi_shifted
-
pred_angles
)
**
2
,
dim
=-
1
)
sq_chi_error
=
torch
.
minimum
(
sq_chi_error
,
sq_chi_error_shifted
)
# The ol' switcheroo
sq_chi_error
=
sq_chi_error
.
permute
(
*
range
(
len
(
sq_chi_error
.
shape
))[
1
:
-
2
],
0
,
-
2
,
-
1
)
sq_chi_loss
=
masked_mean
(
chi_mask
[...,
None
,
:,
:],
sq_chi_error
,
dim
=
(
-
1
,
-
2
,
-
3
)
)
...
...
@@ -1513,39 +1515,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return
loss
def
compute_drmsd
(
structure_1
,
structure_2
,
mask
=
None
):
if
(
mask
is
not
None
):
structure_1
=
structure_1
*
mask
[...,
None
]
structure_2
=
structure_2
*
mask
[...,
None
]
d1
=
structure_1
[...,
:,
None
,
:]
-
structure_1
[...,
None
,
:,
:]
d2
=
structure_2
[...,
:,
None
,
:]
-
structure_2
[...,
None
,
:,
:]
d1
=
d1
**
2
d2
=
d2
**
2
d1
=
torch
.
sqrt
(
torch
.
sum
(
d1
,
dim
=-
1
))
d2
=
torch
.
sqrt
(
torch
.
sum
(
d2
,
dim
=-
1
))
drmsd
=
d1
-
d2
drmsd
=
drmsd
**
2
drmsd
=
torch
.
sum
(
drmsd
,
dim
=
(
-
1
,
-
2
))
n
=
d1
.
shape
[
-
1
]
if
mask
is
None
else
torch
.
sum
(
mask
,
dim
=-
1
)
drmsd
=
drmsd
*
(
1
/
(
n
*
(
n
-
1
)))
if
n
>
1
else
(
drmsd
*
0.
)
drmsd
=
torch
.
sqrt
(
drmsd
)
return
drmsd
def
compute_drmsd_np
(
structure_1
,
structure_2
,
mask
=
None
):
structure_1
=
torch
.
tensor
(
structure_1
)
structure_2
=
torch
.
tensor
(
structure_2
)
if
(
mask
is
not
None
):
mask
=
torch
.
tensor
(
mask
)
return
compute_drmsd
(
structure_1
,
structure_2
,
mask
)
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
def
__init__
(
self
,
config
):
...
...
@@ -1614,6 +1583,10 @@ class AlphaFoldLoss(nn.Module):
weight
=
self
.
config
[
loss_name
].
weight
loss
=
loss_fn
()
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
#for k,v in batch.items():
# if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
# logging.warning(f"{k}: is nan")
#logging.warning(f"{loss_name}: {loss}")
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
cum_loss
=
cum_loss
+
weight
*
loss
...
...
openfold/utils/superimposition.py
View file @
78567a86
...
...
@@ -42,7 +42,7 @@ def _superimpose_single(reference, coords):
return
coords
.
new_tensor
(
superimposed
),
coords
.
new_tensor
(
rmsd
)
def
superimpose
(
reference
,
coords
):
def
superimpose
(
reference
,
coords
,
mask
):
"""
Superimposes coordinates onto a reference by minimizing RMSD using SVD.
...
...
@@ -51,18 +51,42 @@ def superimpose(reference, coords):
[*, N, 3] reference tensor
coords:
[*, N, 3] tensor
mask:
[*, N] tensor
Returns:
A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
"""
def
select_unmasked_coords
(
coords
,
mask
):
return
torch
.
masked_select
(
coords
,
(
mask
>
0.
)[...,
None
],
).
reshape
(
-
1
,
3
)
batch_dims
=
reference
.
shape
[:
-
2
]
flat_reference
=
reference
.
reshape
((
-
1
,)
+
reference
.
shape
[
-
2
:])
flat_coords
=
coords
.
reshape
((
-
1
,)
+
reference
.
shape
[
-
2
:])
flat_mask
=
mask
.
reshape
((
-
1
,)
+
mask
.
shape
[
-
1
:])
superimposed_list
=
[]
rmsds
=
[]
for
r
,
c
in
zip
(
flat_reference
,
flat_coords
):
superimposed
,
rmsd
=
_superimpose_single
(
r
,
c
)
superimposed_list
.
append
(
superimposed
)
rmsds
.
append
(
rmsd
)
for
r
,
c
,
m
in
zip
(
flat_reference
,
flat_coords
,
flat_mask
):
r_unmasked_coords
=
select_unmasked_coords
(
r
,
m
)
c_unmasked_coords
=
select_unmasked_coords
(
c
,
m
)
superimposed
,
rmsd
=
_superimpose_single
(
r_unmasked_coords
,
c_unmasked_coords
)
# This is very inelegant, but idk how else to invert the masking
# procedure.
count
=
0
superimposed_full_size
=
torch
.
zeros_like
(
r
)
for
i
,
unmasked
in
enumerate
(
m
):
if
(
unmasked
):
superimposed_full_size
[
i
]
=
superimposed
[
count
]
count
+=
1
superimposed_list
.
append
(
superimposed_full_size
)
rmsds
.
append
(
rmsd
)
superimposed_stacked
=
torch
.
stack
(
superimposed_list
,
dim
=
0
)
rmsds_stacked
=
torch
.
stack
(
rmsds
,
dim
=
0
)
...
...
openfold/utils/validation_metrics.py
View file @
78567a86
...
...
@@ -14,16 +14,47 @@
import
torch
def
drmsd
(
structure_1
,
structure_2
,
mask
=
None
):
def
prep_d
(
structure
):
d
=
structure
[...,
:,
None
,
:]
-
structure
[...,
None
,
:,
:]
d
=
d
**
2
d
=
torch
.
sqrt
(
torch
.
sum
(
d
,
dim
=-
1
))
return
d
d1
=
prep_d
(
structure_1
)
d2
=
prep_d
(
structure_2
)
drmsd
=
d1
-
d2
drmsd
=
drmsd
**
2
if
(
mask
is
not
None
):
drmsd
=
drmsd
*
(
mask
[...,
None
]
*
mask
[...,
None
,
:])
drmsd
=
torch
.
sum
(
drmsd
,
dim
=
(
-
1
,
-
2
))
n
=
d1
.
shape
[
-
1
]
if
mask
is
None
else
torch
.
sum
(
mask
,
dim
=-
1
)
drmsd
=
drmsd
*
(
1
/
(
n
*
(
n
-
1
)))
if
n
>
1
else
(
drmsd
*
0.
)
drmsd
=
torch
.
sqrt
(
drmsd
)
return
drmsd
def
drmsd_np
(
structure_1
,
structure_2
,
mask
=
None
):
structure_1
=
torch
.
tensor
(
structure_1
)
structure_2
=
torch
.
tensor
(
structure_2
)
if
(
mask
is
not
None
):
mask
=
torch
.
tensor
(
mask
)
return
drmsd
(
structure_1
,
structure_2
,
mask
)
def
gdt
(
p1
,
p2
,
mask
,
cutoffs
):
n
=
torch
.
sum
(
mask
,
dim
=-
1
)
p1
=
p1
.
float
()
p2
=
p2
.
float
()
distances
=
torch
.
sqrt
(
torch
.
sum
((
p1
-
p2
)
**
2
,
dim
=-
1
))
scores
=
[]
for
c
in
cutoffs
:
score
=
torch
.
sum
((
distances
<=
c
)
*
mask
,
dim
=-
1
)
/
n
score
=
torch
.
mean
(
score
)
scores
.
append
(
score
)
return
sum
(
scores
)
/
len
(
scores
)
...
...
train_openfold.py
View file @
78567a86
...
...
@@ -8,6 +8,7 @@ import os
#os.environ["NODE_RANK"]="0"
import
random
import
sys
import
time
import
numpy
as
np
...
...
@@ -32,12 +33,13 @@ from openfold.utils.callbacks import (
EarlyStoppingVerbose
,
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
,
compute_drmsd
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.validation_metrics
import
(
drmsd
,
gdt_ts
,
gdt_ha
,
)
...
...
@@ -59,6 +61,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
self
.
cached_weights
=
None
self
.
last_lr_step
=
0
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
...
...
@@ -172,7 +175,7 @@ class OpenFoldWrapper(pl.LightningModule):
metrics
[
"lddt_ca"
]
=
lddt_ca_score
drmsd_ca_score
=
compute_
drmsd
(
drmsd_ca_score
=
drmsd
(
pred_coords_masked_ca
,
gt_coords_masked_ca
,
mask
=
all_atom_mask_ca
,
# still required here to compute n
...
...
@@ -181,8 +184,8 @@ class OpenFoldWrapper(pl.LightningModule):
metrics
[
"drmsd_ca"
]
=
drmsd_ca_score
if
(
superimposition_metrics
):
superimposed_pred
,
_
=
superimpose
(
gt_coords_masked_ca
,
pred_coords_masked_ca
superimposed_pred
,
alignment_rmsd
=
superimpose
(
gt_coords_masked_ca
,
pred_coords_masked_ca
,
all_atom_mask_ca
,
)
gdt_ts_score
=
gdt_ts
(
superimposed_pred
,
gt_coords_masked_ca
,
all_atom_mask_ca
...
...
@@ -191,6 +194,7 @@ class OpenFoldWrapper(pl.LightningModule):
superimposed_pred
,
gt_coords_masked_ca
,
all_atom_mask_ca
)
metrics
[
"alignment_rmsd"
]
=
alignment_rmsd
metrics
[
"gdt_ts"
]
=
gdt_ts_score
metrics
[
"gdt_ha"
]
=
gdt_ha_score
...
...
@@ -312,7 +316,12 @@ def main(args):
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
else
:
strategy
=
None
if
(
args
.
wandb
):
freeze_path
=
f
"
{
wdb_logger
.
experiment
.
dir
}
/package_versions.txt"
os
.
system
(
f
"
{
sys
.
executable
}
-m pip freeze >
{
freeze_path
}
"
)
wdb_logger
.
experiment
.
save
(
f
"
{
freeze_path
}
"
)
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
default_root_dir
=
args
.
output_dir
,
...
...
@@ -499,9 +508,15 @@ if __name__ == "__main__":
'used.'
)
)
parser
.
add_argument
(
"--_distillation_structure_index_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--_distillation_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment