Commit 6298a3e6 authored by Gustaf Ahdritz
Browse files

Fix PDB parsing, add distillation MSA cropping

parent 4d40ce80
...@@ -179,6 +179,7 @@ config = mlc.ConfigDict( ...@@ -179,6 +179,7 @@ config = mlc.ConfigDict(
"all_atom_positions", "all_atom_positions",
"resolution", "resolution",
"use_clamped_fape", "use_clamped_fape",
"is_distillation",
], ],
}, },
"predict": { "predict": {
...@@ -192,6 +193,7 @@ config = mlc.ConfigDict( ...@@ -192,6 +193,7 @@ config = mlc.ConfigDict(
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"supervised": False, "supervised": False,
"subsample_recycling": False,
}, },
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
...@@ -204,6 +206,7 @@ config = mlc.ConfigDict( ...@@ -204,6 +206,7 @@ config = mlc.ConfigDict(
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"supervised": True, "supervised": True,
"subsample_recycling": False,
}, },
"train": { "train": {
"fixed_size": True, "fixed_size": True,
...@@ -218,6 +221,7 @@ config = mlc.ConfigDict( ...@@ -218,6 +221,7 @@ config = mlc.ConfigDict(
"supervised": True, "supervised": True,
"clamp_prob": 0.9, "clamp_prob": 0.9,
"subsample_recycling": True, "subsample_recycling": True,
"max_distillation_msa_clusters": 1000,
}, },
"data_module": { "data_module": {
"use_small_bfd": False, "use_small_bfd": False,
......
import copy
from functools import partial from functools import partial
import json import json
import logging import logging
...@@ -462,9 +463,9 @@ class DummyDataset(torch.utils.data.Dataset): ...@@ -462,9 +463,9 @@ class DummyDataset(torch.utils.data.Dataset):
class DummyDataLoader(pl.LightningDataModule): class DummyDataLoader(pl.LightningDataModule):
def __init__(self): def __init__(self, batch_path):
super().__init__() super().__init__()
self.dataset = Dataset() self.dataset = DummyDataset(batch_path)
def train_dataloader(self): def train_dataloader(self):
return torch.utils.data.DataLoader(self.dataset) return torch.utils.data.DataLoader(self.dataset)
...@@ -113,6 +113,7 @@ def make_pdb_features( ...@@ -113,6 +113,7 @@ def make_pdb_features(
pdb_feats["all_atom_positions"] = all_atom_positions pdb_feats["all_atom_positions"] = all_atom_positions
pdb_feats["all_atom_mask"] = all_atom_mask pdb_feats["all_atom_mask"] = all_atom_mask
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(1.).astype(np.float32) pdb_feats["is_distillation"] = np.array(1.).astype(np.float32)
return pdb_feats return pdb_feats
...@@ -412,16 +413,12 @@ class DataPipeline: ...@@ -412,16 +413,12 @@ class DataPipeline:
pdb_feats = make_pdb_features(protein_object) pdb_feats = make_pdb_features(protein_object)
mmcif_feats = make_mmcif_features(mmcif, chain_id)
alignments = self._parse_alignment_output(alignment_dir) alignments = self._parse_alignment_output(alignment_dir)
input_sequence = mmcif.chain_to_seqres[chain_id]
templates_result = self.template_featurizer.get_templates( templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence, query_sequence=protein_object.aatype,
query_pdb_code=None, query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]), query_release_date=None,
hits=alignments["hhsearch_hits"], hits=alignments["hhsearch_hits"],
) )
...@@ -438,4 +435,4 @@ class DataPipeline: ...@@ -438,4 +435,4 @@ class DataPipeline:
), ),
) )
return {**mmcif_feats, **templates_result.features, **msa_features} return {**pdb_feats, **templates_result.features, **msa_features}
...@@ -77,14 +77,6 @@ def curry1(f): ...@@ -77,14 +77,6 @@ def curry1(f):
return fc return fc
@curry1
def add_distillation_flag(protein, distillation):
    """Tag ``protein`` with a flag saying whether it is a distillation example.

    ``distillation`` is converted with ``float()`` and stored as a float32
    scalar tensor under the ``"is_distillation"`` key. The same mapping is
    modified in place and returned.
    """
    flag_value = float(distillation)
    protein["is_distillation"] = torch.tensor(flag_value, dtype=torch.float32)
    return protein
def make_all_atom_aatype(protein): def make_all_atom_aatype(protein):
protein["all_atom_aatype"] = protein["aatype"] protein["all_atom_aatype"] = protein["aatype"]
return protein return protein
...@@ -176,7 +168,6 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -176,7 +168,6 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
) )
return protein return protein
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra): def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
...@@ -198,6 +189,13 @@ def sample_msa(protein, max_seq, keep_extra): ...@@ -198,6 +189,13 @@ def sample_msa(protein, max_seq, keep_extra):
return protein return protein
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA, but only for distillation examples.

    When ``protein["is_distillation"]`` equals 1, the features are passed
    through ``sample_msa`` configured with ``max_seq`` and
    ``keep_extra=False``. Any other example is returned unchanged.

    Args:
        protein: Feature mapping; must contain the ``"is_distillation"`` key.
        max_seq: Forwarded to ``sample_msa`` as its ``max_seq`` argument.

    Returns:
        The (possibly subsampled) feature mapping.
    """
    if protein["is_distillation"] == 1:
        # sample_msa is itself wrapped by @curry1. Assuming curry1 supplies
        # every argument except the first (its body is outside this view --
        # TODO confirm), the transform has to be configured first and then
        # applied to the protein. Passing the protein positionally would hand
        # back an unapplied transform instead of the sampled features.
        protein = sample_msa(max_seq, keep_extra=False)(protein)
    return protein
@curry1 @curry1
def crop_extra_msa(protein, max_extra_msa): def crop_extra_msa(protein, max_extra_msa):
num_seq = protein["extra_msa"].shape[0] num_seq = protein["extra_msa"].shape[0]
......
...@@ -25,7 +25,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -25,7 +25,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
transforms = [ transforms = [
data_transforms.cast_to_64bit_ints, data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes, data_transforms.correct_msa_restypes,
data_transforms.add_distillation_flag(False),
data_transforms.squeeze_features, data_transforms.squeeze_features,
data_transforms.randomly_replace_msa_with_unknown(0.0), data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask, data_transforms.make_seq_mask,
...@@ -72,6 +71,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -72,6 +71,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged.""" """Input pipeline data transformers that can be ensembled and averaged."""
transforms = [] transforms = []
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
mode_cfg.max_distillation_msa_clusters
)
)
if common_cfg.reduce_msa_clusters_by_max_templates: if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else: else:
......
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import logging import logging
import os import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6" #os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14" #os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069" #os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0" #os.environ["NODE_RANK"]="0"
...@@ -223,7 +223,7 @@ if __name__ == "__main__": ...@@ -223,7 +223,7 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled" help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
) )
parser.add_argument( parser.add_argument(
"--checkpoint_best_val", type=int, default=True, "--checkpoint_best_val", type=bool, default=True,
help="""Whether to save the model parameters that perform best during help="""Whether to save the model parameters that perform best during
validation""" validation"""
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment