"segmentation/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "b1238b9da254025fab301c4e5a0cdc298795e419"
Commit 78567a86 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix validation metric bugs

parent e1796607
......@@ -163,8 +163,8 @@ def make_protein_features(
def make_pdb_features(
protein_object: protein.Protein,
description: str,
confidence_threshold: float = 0.5,
is_distillation: bool = True,
confidence_threshold: float = 50.,
) -> FeatureDict:
pdb_feats = make_protein_features(
protein_object, description, _is_distillation=True
......@@ -173,9 +173,7 @@ def make_pdb_features(
if(is_distillation):
high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1)
for i, confident in enumerate(high_confidence):
if(not confident):
pdb_feats["all_atom_mask"][i] = 0
pdb_feats["all_atom_mask"] *= high_confidence[..., None]
return pdb_feats
......@@ -620,13 +618,24 @@ class DataPipeline:
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
"""
with open(pdb_path, 'r') as f:
pdb_str = f.read()
if(_structure_index is not None):
db_dir = os.path.dirname(pdb_path)
db = _structure_index["db"]
db_path = os.path.join(db_dir, db)
fp = open(db_path, "rb")
_, offset, length = _structure_index["files"][0]
fp.seek(offset)
pdb_str = fp.read(length).decode("utf-8")
fp.close()
else:
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
......@@ -634,7 +643,7 @@ class DataPipeline:
pdb_feats = make_pdb_features(
protein_object,
description,
is_distillation
is_distillation=is_distillation
)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
......
......@@ -463,6 +463,7 @@ def make_masked_msa(protein, config, replace_fraction):
1.0 - config.profile_prob - config.same_prob - config.uniform_prob
)
assert mask_prob >= 0.0
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
......
......@@ -334,10 +334,12 @@ def supervised_chi_loss(
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
sq_chi_loss = masked_mean(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
......@@ -1513,39 +1515,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return loss
def compute_drmsd(structure_1, structure_2, mask=None):
    """Compute the distance-matrix RMSD (dRMSD) between two structures.

    Args:
        structure_1: [*, N, 3] tensor of atom coordinates.
        structure_2: [*, N, 3] tensor of atom coordinates.
        mask: Optional [*, N] tensor; nonzero where an atom is valid.
    Returns:
        [*] tensor of per-structure dRMSD values.
    """
    def _pairwise_distances(structure):
        # [*, N, N] Euclidean distances between all pairs of atoms
        diffs = structure[..., :, None, :] - structure[..., None, :, :]
        return torch.sqrt(torch.sum(diffs ** 2, dim=-1))

    d1 = _pairwise_distances(structure_1)
    d2 = _pairwise_distances(structure_2)

    sq_diff = (d1 - d2) ** 2
    if mask is not None:
        # Zero out every pair that involves a masked atom. Masking the
        # distance matrix (rather than zeroing coordinates up front, as the
        # previous version did) avoids spurious contributions from
        # masked-vs-unmasked pairs, whose "distance" would otherwise be the
        # distance of the unmasked atom to the origin.
        sq_diff = sq_diff * (mask[..., None] * mask[..., None, :])
        n = torch.sum(mask, dim=-1)
    else:
        n = d1.new_tensor(d1.shape[-1])

    drmsd = torch.sum(sq_diff, dim=(-1, -2))
    # Clamp the pair count so structures with n <= 1 valid atoms yield 0
    # instead of NaN/inf (the masked sum is already 0 in that case). Unlike
    # a python `if n > 1`, this also works for batched (tensor) n.
    drmsd = drmsd / torch.clamp(n * (n - 1), min=1)
    return torch.sqrt(drmsd)
def compute_drmsd_np(structure_1, structure_2, mask=None):
    """NumPy-friendly wrapper around compute_drmsd.

    Converts array-like inputs to torch tensors and delegates.
    """
    as_tensor_1 = torch.tensor(structure_1)
    as_tensor_2 = torch.tensor(structure_2)
    mask_tensor = None if mask is None else torch.tensor(mask)
    return compute_drmsd(as_tensor_1, as_tensor_2, mask_tensor)
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
......@@ -1614,6 +1583,10 @@ class AlphaFoldLoss(nn.Module):
weight = self.config[loss_name].weight
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
#for k,v in batch.items():
# if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
# logging.warning(f"{k}: is nan")
#logging.warning(f"{loss_name}: {loss}")
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
......
......@@ -42,7 +42,7 @@ def _superimpose_single(reference, coords):
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
def superimpose(reference, coords):
def superimpose(reference, coords, mask):
"""
Superimposes coordinates onto a reference by minimizing RMSD using SVD.
......@@ -51,18 +51,42 @@ def superimpose(reference, coords):
[*, N, 3] reference tensor
coords:
[*, N, 3] tensor
mask:
[*, N] tensor
Returns:
A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
"""
def select_unmasked_coords(coords, mask):
return torch.masked_select(
coords,
(mask > 0.)[..., None],
).reshape(-1, 3)
batch_dims = reference.shape[:-2]
flat_reference = reference.reshape((-1,) + reference.shape[-2:])
flat_coords = coords.reshape((-1,) + reference.shape[-2:])
flat_mask = mask.reshape((-1,) + mask.shape[-1:])
superimposed_list = []
rmsds = []
for r, c in zip(flat_reference, flat_coords):
superimposed, rmsd = _superimpose_single(r, c)
superimposed_list.append(superimposed)
rmsds.append(rmsd)
for r, c, m in zip(flat_reference, flat_coords, flat_mask):
r_unmasked_coords = select_unmasked_coords(r, m)
c_unmasked_coords = select_unmasked_coords(c, m)
superimposed, rmsd = _superimpose_single(
r_unmasked_coords,
c_unmasked_coords
)
# This is very inelegant, but idk how else to invert the masking
# procedure.
count = 0
superimposed_full_size = torch.zeros_like(r)
for i, unmasked in enumerate(m):
if(unmasked):
superimposed_full_size[i] = superimposed[count]
count += 1
superimposed_list.append(superimposed_full_size)
rmsds.append(rmsd)
superimposed_stacked = torch.stack(superimposed_list, dim=0)
rmsds_stacked = torch.stack(rmsds, dim=0)
......
......@@ -14,16 +14,47 @@
import torch
def drmsd(structure_1, structure_2, mask=None):
    """Compute the distance-matrix RMSD (dRMSD) between two structures.

    Args:
        structure_1: [*, N, 3] tensor of atom coordinates.
        structure_2: [*, N, 3] tensor of atom coordinates.
        mask: Optional [*, N] tensor; nonzero where an atom is valid.
    Returns:
        [*] tensor of per-structure dRMSD values.
    """
    def prep_d(structure):
        # [*, N, N] Euclidean distances between all pairs of atoms
        d = structure[..., :, None, :] - structure[..., None, :, :]
        d = d ** 2
        d = torch.sqrt(torch.sum(d, dim=-1))
        return d

    d1 = prep_d(structure_1)
    d2 = prep_d(structure_2)

    drmsd = d1 - d2
    drmsd = drmsd ** 2
    if mask is not None:
        # Zero out every pair that involves a masked atom
        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
    drmsd = torch.sum(drmsd, dim=(-1, -2))
    n = d1.new_tensor(d1.shape[-1]) if mask is None else torch.sum(mask, dim=-1)
    # Clamp the pair count so structures with n <= 1 valid atoms yield 0
    # instead of NaN/inf (the masked sum is already 0 in that case). The
    # previous `if n > 1` guard raised "Boolean value of Tensor is
    # ambiguous" whenever the mask had batch dimensions.
    drmsd = drmsd / torch.clamp(n * (n - 1), min=1)
    drmsd = torch.sqrt(drmsd)
    return drmsd
def drmsd_np(structure_1, structure_2, mask=None):
    """NumPy-friendly wrapper around drmsd.

    Converts array-like inputs to torch tensors and delegates.
    """
    tensors = [torch.tensor(structure_1), torch.tensor(structure_2)]
    if mask is not None:
        mask = torch.tensor(mask)
    return drmsd(tensors[0], tensors[1], mask)
def gdt(p1, p2, mask, cutoffs):
    """Global distance test: mean fraction of valid atoms within each cutoff.

    Args:
        p1: [*, N, 3] tensor of predicted coordinates.
        p2: [*, N, 3] tensor of reference coordinates.
        mask: [*, N] tensor; nonzero where an atom is valid.
        cutoffs: iterable of distance thresholds.
    Returns:
        Scalar tensor: the per-cutoff scores (each averaged over the batch)
        averaged over all cutoffs.
    """
    n_valid = torch.sum(mask, dim=-1)
    per_atom_dist = torch.sqrt(
        torch.sum((p1.float() - p2.float()) ** 2, dim=-1)
    )
    cutoff_scores = [
        torch.mean(torch.sum((per_atom_dist <= c) * mask, dim=-1) / n_valid)
        for c in cutoffs
    ]
    return sum(cutoff_scores) / len(cutoff_scores)
......
......@@ -8,6 +8,7 @@ import os
#os.environ["NODE_RANK"]="0"
import random
import sys
import time
import numpy as np
......@@ -32,12 +33,13 @@ from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
)
......@@ -59,6 +61,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
self.cached_weights = None
self.last_lr_step = 0
def forward(self, batch):
return self.model(batch)
......@@ -172,7 +175,7 @@ class OpenFoldWrapper(pl.LightningModule):
metrics["lddt_ca"] = lddt_ca_score
drmsd_ca_score = compute_drmsd(
drmsd_ca_score = drmsd(
pred_coords_masked_ca,
gt_coords_masked_ca,
mask=all_atom_mask_ca, # still required here to compute n
......@@ -181,8 +184,8 @@ class OpenFoldWrapper(pl.LightningModule):
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
superimposed_pred, _ = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
gdt_ts_score = gdt_ts(
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
......@@ -191,6 +194,7 @@ class OpenFoldWrapper(pl.LightningModule):
superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
)
metrics["alignment_rmsd"] = alignment_rmsd
metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ha"] = gdt_ha_score
......@@ -312,7 +316,12 @@ def main(args):
strategy = DDPPlugin(find_unused_parameters=False)
else:
strategy = None
if(args.wandb):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args(
args,
default_root_dir=args.output_dir,
......@@ -499,9 +508,15 @@ if __name__ == "__main__":
'used.'
)
)
parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None,
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
)
parser.add_argument(
"--_distillation_alignment_index_path", type=str, default=None,
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment