Commit 6298a3e6 authored by Gustaf Ahdritz
Browse files

Fix PDB parsing, add distillation MSA cropping

parent 4d40ce80
......@@ -179,6 +179,7 @@ config = mlc.ConfigDict(
"all_atom_positions",
"resolution",
"use_clamped_fape",
"is_distillation",
],
},
"predict": {
......@@ -192,6 +193,7 @@ config = mlc.ConfigDict(
"crop": False,
"crop_size": None,
"supervised": False,
"subsample_recycling": False,
},
"eval": {
"fixed_size": True,
......@@ -204,6 +206,7 @@ config = mlc.ConfigDict(
"crop": False,
"crop_size": None,
"supervised": True,
"subsample_recycling": False,
},
"train": {
"fixed_size": True,
......@@ -218,6 +221,7 @@ config = mlc.ConfigDict(
"supervised": True,
"clamp_prob": 0.9,
"subsample_recycling": True,
"max_distillation_msa_clusters": 1000,
},
"data_module": {
"use_small_bfd": False,
......
import copy
from functools import partial
import json
import logging
......@@ -462,9 +463,9 @@ class DummyDataset(torch.utils.data.Dataset):
class DummyDataLoader(pl.LightningDataModule):
def __init__(self):
def __init__(self, batch_path):
super().__init__()
self.dataset = Dataset()
self.dataset = DummyDataset(batch_path)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.dataset)
......@@ -113,6 +113,7 @@ def make_pdb_features(
pdb_feats["all_atom_positions"] = all_atom_positions
pdb_feats["all_atom_mask"] = all_atom_mask
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(1.).astype(np.float32)
return pdb_feats
......@@ -412,16 +413,12 @@ class DataPipeline:
pdb_feats = make_pdb_features(protein_object)
mmcif_feats = make_mmcif_features(mmcif, chain_id)
alignments = self._parse_alignment_output(alignment_dir)
input_sequence = mmcif.chain_to_seqres[chain_id]
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_sequence=protein_object.aatype,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
query_release_date=None,
hits=alignments["hhsearch_hits"],
)
......@@ -438,4 +435,4 @@ class DataPipeline:
),
)
return {**mmcif_feats, **templates_result.features, **msa_features}
return {**pdb_feats, **templates_result.features, **msa_features}
......@@ -77,14 +77,6 @@ def curry1(f):
return fc
@curry1
def add_distillation_flag(protein, distillation):
    """Store the distillation flag on the feature dict.

    The flag is written under the ``"is_distillation"`` key as a float32
    scalar tensor built from ``float(distillation)``. The same ``protein``
    mapping is returned after being updated in place.

    NOTE(review): decorated with ``curry1``; presumably callers supply
    ``distillation`` first and the protein later -- confirm against the
    ``curry1`` definition.
    """
    flag = torch.tensor(float(distillation), dtype=torch.float32)
    protein["is_distillation"] = flag
    return protein
def make_all_atom_aatype(protein):
    """Expose the ``aatype`` feature under the ``all_atom_aatype`` key.

    The value is not copied: both keys refer to the same object. The
    ``protein`` mapping is modified in place and returned.
    """
    aatype = protein["aatype"]
    protein["all_atom_aatype"] = aatype
    return protein
......@@ -176,7 +168,6 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
)
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`."""
......@@ -198,6 +189,13 @@ def sample_msa(protein, max_seq, keep_extra):
return protein
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA of distillation examples only.

    When ``protein["is_distillation"]`` equals 1, the MSA is reduced with
    ``sample_msa`` to at most ``max_seq`` sequences and the remainder is
    discarded (``keep_extra=False``). Other examples pass through unchanged.

    Args:
        protein: Feature dict; must contain ``"is_distillation"``.
        max_seq: Maximum number of MSA sequences to keep.

    Returns:
        The (possibly subsampled) feature dict.
    """
    if protein["is_distillation"] == 1:
        # sample_msa is itself wrapped by curry1, so it has to be configured
        # first and then applied to the protein. Passing the protein as the
        # first positional argument would bind it as max_seq and hand back
        # an unapplied transform instead of a feature dict.
        protein = sample_msa(max_seq, keep_extra=False)(protein)
    return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
num_seq = protein["extra_msa"].shape[0]
......
......@@ -25,7 +25,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes,
data_transforms.add_distillation_flag(False),
data_transforms.squeeze_features,
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
......@@ -72,6 +71,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
mode_cfg.max_distillation_msa_clusters
)
)
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
......
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
......@@ -223,7 +223,7 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser.add_argument(
    "--checkpoint_best_val",
    # argparse's type=bool converts every non-empty string (including
    # "False" and "0") to True, so the flag could never be switched off from
    # the command line. Parse the text explicitly instead.
    type=lambda value: str(value).strip().lower() in (
        "1", "true", "t", "yes", "y", "on"
    ),
    default=True,
    help="""Whether to save the model parameters that perform best during
validation"""
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment