Commit 6ce8cfe3 authored by Gustaf Ahdritz

Fixes

parent 1df4991d
@@ -4,7 +4,7 @@ import json
 import logging
 import os
 import pickle
-from typing import Optional, Sequence
+from typing import Optional, Sequence, List, Any
 import ml_collections as mlc
 import numpy as np
@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         max_template_date: str,
         config: mlc.ConfigDict,
         kalign_binary_path: str = '/usr/bin/kalign',
-        mapping_path: Optional[str] = None,
         max_template_hits: int = 4,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         shuffle_top_k_prefiltered: Optional[int] = None,
         treat_pdb_as_distillation: bool = True,
+        mapping_path: Optional[str] = None,
         mode: str = "train",
         _output_raw: bool = False,
+        _alignment_index: Optional[Any] = None
     ):
         """
         Args:
@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 A dataset config object. See openfold.config
             kalign_binary_path:
                 Path to kalign binary.
-            mapping_path:
-                A json file containing a mapping from consecutive numerical
-                ids to sample names (matching the directories in data_dir).
-                Samples not in this mapping are ignored. Can be used to
-                implement the various training-time filters described in
-                the AlphaFold supplement.
             max_template_hits:
                 An upper bound on how many templates are considered. During
                 training, the templates ultimately used are subsampled
@@ -89,26 +84,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.treat_pdb_as_distillation = treat_pdb_as_distillation
         self.mode = mode
         self._output_raw = _output_raw
+        self._alignment_index = _alignment_index

         valid_modes = ["train", "eval", "predict"]
         if(mode not in valid_modes):
             raise ValueError(f'mode must be one of {valid_modes}')

-        if(mapping_path is None):
-            self.mapping = {
-                str(i):os.path.splitext(name)[0]
-                for i, name in enumerate(os.listdir(alignment_dir))
-            }
-        else:
-            with open(mapping_path, 'r') as fp:
-                self.mapping = json.load(fp)
-
         if(template_release_dates_cache_path is None):
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

+        if(mapping_path is None):
+            self._chain_ids = list(os.listdir(alignment_dir))
+        else:
+            with open(mapping_path, "r") as f:
+                self._chain_ids = [l.strip() for l in f.readlines()]
+
+        self._chain_id_to_idx_dict = {
+            chain: i for i, chain in enumerate(self._chain_ids)
+        }

         template_featurizer = templates.TemplateHitFeaturizer(
             mmcif_dir=template_mmcif_dir,
             max_template_date=max_template_date,
@@ -126,7 +123,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         if(not self._output_raw):
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
+    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
         with open(path, 'r') as f:
             mmcif_string = f.read()
@@ -145,14 +142,25 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             mmcif=mmcif_object,
             alignment_dir=alignment_dir,
             chain_id=chain_id,
+            _alignment_index=_alignment_index
         )

         return data

+    def chain_id_to_idx(self, chain_id):
+        return self._chain_id_to_idx_dict[chain_id]
+
+    def idx_to_chain_id(self, idx):
+        return self._chain_ids[idx]
+
     def __getitem__(self, idx):
-        name = self.mapping[str(idx)]
+        name = self.idx_to_chain_id(idx)
         alignment_dir = os.path.join(self.alignment_dir, name)

+        _alignment_index = None
+        if(self._alignment_index is not None):
+            _alignment_index = self._alignment_index[name]
+
         if(self.mode == 'train' or self.mode == 'eval'):
             spl = name.rsplit('_', 1)
             if(len(spl) == 2):
@@ -164,11 +172,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             path = os.path.join(self.data_dir, file_id)
             if(os.path.exists(path + ".cif")):
                 data = self._parse_mmcif(
-                    path + ".cif", file_id, chain_id, alignment_dir
+                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".core")):
                 data = self.data_pipeline.process_core(
-                    path + ".core", alignment_dir
+                    path + ".core", alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".pdb")):
                 data = self.data_pipeline.process_pdb(
@@ -176,6 +184,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     alignment_dir=alignment_dir,
                     is_distillation=self.treat_pdb_as_distillation,
                     chain_id=chain_id,
+                    _alignment_index=_alignment_index,
                 )
             else:
                 raise ValueError("Invalid file type")
@@ -184,6 +193,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             data = self.data_pipeline.process_fasta(
                 fasta_path=path,
                 alignment_dir=alignment_dir,
+                _alignment_index=_alignment_index,
             )

         if(self._output_raw):
@@ -196,56 +206,130 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         return feats

     def __len__(self):
-        return len(self.mapping.keys())
+        return len(self._chain_ids)


-def looped_sequence(sequence):
-    while True:
-        for x in sequence:
-            yield x
+def train_filter(
+    prot_data_cache_entry: Any,
+    generator: torch.Generator,
+    max_resolution: float = 9.,
+    max_single_aa_prop: float = 0.8,
+) -> bool:
+    # Hard filters
+    resolution = prot_data_cache_entry.get("resolution", None)
+    if(resolution is not None and resolution > max_resolution):
+        return False
+
+    seq = prot_data_cache_entry["seq"]
+    counts = {}
+    for aa in seq:
+        counts.setdefault(aa, 0)
+        counts[aa] += 1
+    largest_aa_count = max(counts.values())
+    largest_single_aa_prop = largest_aa_count / len(seq)
+    if(largest_single_aa_prop > max_single_aa_prop):
+        return False
+
+    # Stochastic filters
+    probabilities = []
+
+    cluster_size = prot_data_cache_entry.get("cluster_size", None)
+    if(cluster_size is not None and cluster_size > 0):
+        probabilities.append(1 / cluster_size)
+
+    chain_length = len(prot_data_cache_entry["seq"])
+    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
+
+    weights = [[1 - p, p] for p in probabilities]
+    results = torch.multinomial(
+        torch.tensor(weights),
+        num_samples=1,
+        generator=generator,
+    )
+
+    return torch.all(results)


-class OpenFoldDataset(torch.utils.data.IterableDataset):
+class OpenFoldDataset(torch.utils.data.Dataset):
     """
-    The Dataset is written to accommodate the requirement that proteins are
-    sampled from the distillation set with some probability p
-    and from the PDB set with probability (1 - p). Proteins are sampled
-    from both sets without replacement, and as soon as either set is
-    emptied, it is refilled. The Dataset therefore has an arbitrary length.
-    Nevertheless, for compatibility with various PyTorch Lightning
-    functionalities, it is possible to specify an epoch length. This length
-    has no effect on the output of the Dataset.
+    Implements the stochastic filters applied during AlphaFold's training.
+    Because samples are selected from constituent datasets randomly, the
+    length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
+    and filtered once at initialization.
     """
     def __init__(self,
         datasets: Sequence[OpenFoldSingleDataset],
         probabilities: Sequence[int],
         epoch_len: int,
+        prot_data_cache_paths: List[str],
+        filter_fn: Optional[Any] = train_filter,
+        generator: torch.Generator = None,
+        _roll_at_init: bool = True,
     ):
         self.datasets = datasets
-        self.samplers = [
-            looped_sequence(RandomSampler(d)) for d in datasets
-        ]
+        self.probabilities = probabilities
         self.epoch_len = epoch_len
+        self.generator = generator
+        self.filter_fn = filter_fn
+
+        def looped_shuffled_dataset_idx(dataset_len):
+            while True:
+                # Uniformly shuffle each dataset's indices
+                weights = [1. for _ in range(dataset_len)]
+                shuf = torch.multinomial(
+                    torch.tensor(weights),
+                    num_samples=dataset_len,
+                    replacement=False,
+                    generator=self.generator,
+                )
+                for idx in shuf:
+                    yield idx

-        self.distr = torch.distributions.categorical.Categorical(
-            probs=torch.tensor(probabilities),
-        )
+        self.shuffled_idx_iters = []
+        for d in datasets:
+            self.shuffled_idx_iters.append(
+                looped_shuffled_dataset_idx(len(d))
+            )

-    def __iter__(self):
-        return self
+        self.prot_data_caches = []
+        for path in prot_data_cache_paths:
+            with open(path, "r") as fp:
+                self.prot_data_caches.append(json.load(fp))

-    def __next__(self):
-        dataset_idx = self.distr.sample()
-        sampler = self.samplers[dataset_idx]
-        element_idx = next(sampler)
-        return self.datasets[dataset_idx][element_idx]
+        if(_roll_at_init):
+            self.reroll()
+
+    def __getitem__(self, idx):
+        dataset_idx, datapoint_idx = self.datapoints[idx]
+        return self.datasets[dataset_idx][datapoint_idx]

     def __len__(self):
         return self.epoch_len

+    def reroll(self):
+        dataset_choices = torch.multinomial(
+            torch.tensor(self.probabilities),
+            num_samples=self.epoch_len,
+            replacement=True,
+            generator=self.generator,
+        )
+
+        self.datapoints = []
+        for dataset_idx in dataset_choices:
+            dataset = self.datasets[dataset_idx]
+            idx_iter = self.shuffled_idx_iters[dataset_idx]
+            prot_data_cache = self.prot_data_caches[dataset_idx]
+
+            datapoint_idx = None
+            while datapoint_idx is None:
+                candidate_idx = next(idx_iter)
+                chain_id = dataset.idx_to_chain_id(candidate_idx)
+                if(self.filter_fn(prot_data_cache[chain_id], self.generator)):
+                    datapoint_idx = candidate_idx
+
+            self.datapoints.append((dataset_idx, datapoint_idx))


 class OpenFoldBatchCollator:
-    def __init__(self, config, generator, stage="train"):
+    def __init__(self, config, stage="train"):
         self.stage = stage
         self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
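For orientation, the new train_filter reproduces AlphaFold's training-time chain filters: a chain is dropped outright when its resolution exceeds 9 Å or when a single residue type makes up more than 80% of its sequence, and otherwise survives with probability (1/cluster_size) * clamp(length, 256, 512)/512, since torch.all over the Bernoulli draws requires every stochastic filter to pass. A rough sketch of exercising it on a hypothetical cache entry (the field names follow the code above; the concrete values are made up):

```python
# Sketch: estimate the acceptance rate of train_filter (defined above) on a
# hypothetical protein data cache entry. Assumes the function is importable.
import torch

entry = {
    "resolution": 2.3,         # passes the 9 A hard filter
    "seq": "MKTAYIAKQR" * 30,  # 300 residues -> keep prob 300/512
    "cluster_size": 4,         # -> keep prob 1/4
}

g = torch.Generator().manual_seed(42)
trials = 10_000
kept = sum(int(train_filter(entry, g)) for _ in range(trials))
print(f"kept {kept}/{trials}; expected ~{(1 / 4) * (300 / 512):.3f}")
```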
@@ -283,21 +367,20 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
             keyed_probs.append(
                 ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
             )

-        if(self.config.supervised.uniform_recycling):
+        if(stage_cfg.uniform_recycling):
             recycling_probs = [
                 1. / (max_iters + 1) for _ in range(max_iters + 1)
             ]
-            keyed_probs.append(
-                ("no_recycling_iters", recycling_probs)
-            )
         else:
             recycling_probs = [
                 0. for _ in range(max_iters + 1)
             ]
             recycling_probs[-1] = 1.
-            keyed_probs.append(
-                ("no_recycling_iters", recycling_probs)
-            )
+
+        keyed_probs.append(
+            ("no_recycling_iters", recycling_probs)
+        )

         keys, probs = zip(*keyed_probs)
         max_len = max([len(p) for p in probs])
@@ -362,8 +445,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
+        train_filter_fn: Optional[Any] = train_filter,
+        train_prot_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
+        distillation_prot_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,
@@ -374,6 +460,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         batch_seed: Optional[int] = None,
+        train_epoch_len: int = 50000,
+        _alignment_index_path: Optional[str] = None,
         **kwargs
     ):
         super(OpenFoldDataModule, self).__init__()
@@ -383,8 +471,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
+        self.train_filter_fn = train_filter_fn
+        self.train_prot_data_cache_path = train_prot_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
+        self.distillation_prot_data_cache_path = (
+            distillation_prot_data_cache_path
+        )
         self.val_data_dir = val_data_dir
         self.val_alignment_dir = val_alignment_dir
         self.predict_data_dir = predict_data_dir
@@ -397,6 +490,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         )
         self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
         self.batch_seed = batch_seed
+        self.train_epoch_len = train_epoch_len

         if(self.train_data_dir is None and self.predict_data_dir is None):
             raise ValueError(
@@ -406,11 +500,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.training_mode = self.train_data_dir is not None

-        if(self.training_mode and self.train_alignment_dir is None):
+        if(self.training_mode and train_alignment_dir is None):
             raise ValueError(
                 'In training mode, train_alignment_dir must be specified'
             )
-        elif(not self.training_mode and self.predict_alingment_dir is None):
+        elif(not self.training_mode and predict_alignment_dir is None):
             raise ValueError(
                 'In inference mode, predict_alignment_dir must be specified'
             )
@@ -420,10 +514,28 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 'be specified as well'
             )

-    def setup(self, stage: Optional[str] = None):
-        if(stage is None):
-            stage = "train"
+        cache_missing = (
+            train_filter_fn and
+            (
+                train_prot_data_cache_path is None or
+                (
+                    distillation_data_dir is not None and
+                    distillation_prot_data_cache_path is None
+                )
+            )
+        )
+        if(cache_missing):
+            raise ValueError(
+                "If train_filter_fn is given, so must the protein data caches"
+            )
+
+        # An ad-hoc measure for our particular filesystem restrictions
+        self._alignment_index = None
+        if(_alignment_index_path is not None):
+            with open(_alignment_index_path, "r") as fp:
+                self._alignment_index = json.load(fp)
+
+    def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
@@ -434,10 +546,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 self.template_release_dates_cache_path,
             obsolete_pdbs_file_path=
                 self.obsolete_pdbs_file_path,
+            _alignment_index=self._alignment_index,
         )

         if(self.training_mode):
-            self.train_dataset = dataset_gen(
+            train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 alignment_dir=self.train_alignment_dir,
                 mapping_path=self.train_mapping_path,
@@ -449,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 _output_raw=True,
             )

+            distillation_dataset = None
             if(self.distillation_data_dir is not None):
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,
@@ -461,13 +575,30 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 )

-                d_prob = self.config.train.distillation_prob
-                self.train_dataset = OpenFoldDataset(
-                    datasets=[self.train_dataset, distillation_dataset],
-                    probabilities=[1 - d_prob, d_prob],
-                    epoch_len=(
-                        self.train_dataset.len() + distillation_dataset.len()
-                    ),
-                )
+            if(distillation_dataset is not None):
+                datasets = [train_dataset, distillation_dataset]
+                d_prob = self.config.train.distillation_prob
+                probabilities = [1 - d_prob, d_prob]
+                prot_data_cache_paths = [
+                    self.train_prot_data_cache_path,
+                    self.distillation_prot_data_cache_path,
+                ]
+            else:
+                datasets = [train_dataset]
+                probabilities = [1.]
+                prot_data_cache_paths = [
+                    self.train_prot_data_cache_path,
+                ]
+
+            self.train_dataset = OpenFoldDataset(
+                datasets=datasets,
+                probabilities=probabilities,
+                epoch_len=self.train_epoch_len,
+                prot_data_cache_paths=prot_data_cache_paths,
+                filter_fn=self.train_filter_fn,
+                _roll_at_init=False,
+            )

         if(self.val_data_dir is not None):
             self.eval_dataset = dataset_gen(
@@ -497,6 +628,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
         dataset = None
         if(stage == "train"):
             dataset = self.train_dataset
+
+            # Filter the dataset, if necessary
+            dataset.reroll()
         elif(stage == "eval"):
             dataset = self.eval_dataset
         elif(stage == "predict"):
...
@@ -422,42 +422,86 @@ class DataPipeline:
     def _parse_msa_data(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}
-        for f in os.listdir(alignment_dir):
-            path = os.path.join(alignment_dir, f)
-            ext = os.path.splitext(f)[-1]
-
-            if(ext == ".a3m"):
-                with open(path, "r") as fp:
-                    msa, deletion_matrix = parsers.parse_a3m(fp.read())
-                data = {"msa": msa, "deletion_matrix": deletion_matrix}
-            elif(ext == ".sto"):
-                with open(path, "r") as fp:
-                    msa, deletion_matrix, _ = parsers.parse_stockholm(
-                        fp.read()
-                    )
-                data = {"msa": msa, "deletion_matrix": deletion_matrix}
-            else:
-                continue
-
-            msa_data[f] = data
+        if(_alignment_index is not None):
+            fp = open(_alignment_index["db"], "rb")
+
+            def read_msa(start, size):
+                fp.seek(start)
+                msa = fp.read(size).decode("utf-8")
+                return msa
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".a3m"):
+                    msa, deletion_matrix = parsers.parse_a3m(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                elif(ext == ".sto"):
+                    msa, deletion_matrix, _ = parsers.parse_stockholm(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                else:
+                    continue
+
+                msa_data[name] = data
+
+            fp.close()
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
+
+                if(ext == ".a3m"):
+                    with open(path, "r") as fp:
+                        msa, deletion_matrix = parsers.parse_a3m(fp.read())
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                elif(ext == ".sto"):
+                    with open(path, "r") as fp:
+                        msa, deletion_matrix, _ = parsers.parse_stockholm(
+                            fp.read()
+                        )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                else:
+                    continue
+
+                msa_data[f] = data

         return msa_data

     def _parse_template_hits(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None
     ) -> Mapping[str, Any]:
         all_hits = {}
-        for f in os.listdir(alignment_dir):
-            path = os.path.join(alignment_dir, f)
-            ext = os.path.splitext(f)[-1]
-
-            if(ext == ".hhr"):
-                with open(path, "r") as fp:
-                    hits = parsers.parse_hhr(fp.read())
-                all_hits[f] = hits
+        if(_alignment_index is not None):
+            fp = open(_alignment_index["db"], 'rb')
+
+            def read_template(start, size):
+                fp.seek(start)
+                return fp.read(size).decode("utf-8")
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".hhr"):
+                    hits = parsers.parse_hhr(read_template(start, size))
+                    all_hits[name] = hits
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
+
+                if(ext == ".hhr"):
+                    with open(path, "r") as fp:
+                        hits = parsers.parse_hhr(fp.read())
+                    all_hits[f] = hits

         return all_hits
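The _alignment_index plumbing threaded through this file assumes a small JSON index over one concatenated alignment database: a "db" path plus (name, start, size) triples under "files", giving each original alignment file's byte offset and length. A hedged sketch of building such an index for one chain (the packing helper and all paths are illustrative, not part of this commit):

```python
# Sketch: pack per-chain alignment files into one flat database and emit the
# index structure _parse_msa_data/_parse_template_hits above expect:
#     {"db": <path>, "files": [[name, start, size], ...]}
# Paths and names here are hypothetical.
import json
import os

def build_alignment_index(alignment_dir: str, db_path: str) -> dict:
    files = []
    with open(db_path, "wb") as db:
        for name in sorted(os.listdir(alignment_dir)):
            with open(os.path.join(alignment_dir, name), "rb") as f:
                data = f.read()
            files.append([name, db.tell(), len(data)])  # offset before write
            db.write(data)
    return {"db": db_path, "files": files}

index = build_alignment_index("alignments/1abc_A", "alignments.db")
with open("alignment_index.json", "w") as fp:
    json.dump(index, fp)
```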
@@ -465,8 +509,9 @@ class DataPipeline:
         self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,
+        _alignment_index: Optional[str] = None
     ) -> Mapping[str, Any]:
-        msa_data = self._parse_msa_data(alignment_dir)
+        msa_data = self._parse_msa_data(alignment_dir, _alignment_index)

         if(len(msa_data) == 0):
             if(input_sequence is None):
@@ -496,6 +541,7 @@ class DataPipeline:
         self,
         fasta_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """Assembles features for a single sequence in a FASTA file"""
         with open(fasta_path) as f:
@@ -509,7 +555,7 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -535,6 +581,7 @@ class DataPipeline:
         mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
         alignment_dir: str,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
         Assembles features for a specific chain in an mmCIF object.
@@ -552,7 +599,7 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)
         input_sequence = mmcif.chain_to_seqres[chain_id]

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -570,6 +617,7 @@ class DataPipeline:
         alignment_dir: str,
         is_distillation: bool = True,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
         Assembles features for a protein in a PDB file.
@@ -586,7 +634,7 @@ class DataPipeline:
             is_distillation
         )

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -601,6 +649,7 @@ class DataPipeline:
         self,
         core_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
         Assembles features for a protein in a ProteinNet .core file.
@@ -613,7 +662,7 @@ class DataPipeline:
         description = os.path.splitext(os.path.basename(core_path))[0].upper()
         core_feats = make_protein_features(protein_object, description)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
...
@@ -360,7 +360,8 @@ class ExtraMSABlock(nn.Module):
                 mask=msa_mask,
                 chunk_size=chunk_size,
                 _chunk_logits=_chunk_logits,
-                _checkpoint_chunks=self.ckpt,
+                _checkpoint_chunks=
+                    self.ckpt if torch.is_grad_enabled() else False,
             )
         )
...
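This change gates chunk checkpointing on torch.is_grad_enabled(): activation checkpointing only pays off when a backward pass will follow, so under torch.no_grad() (e.g. validation or inference) it would add recomputation bookkeeping for nothing. A toy illustration of the same guard, with a module and sizes that are purely illustrative:

```python
# Sketch: only checkpoint activations when gradients are being tracked.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)

def forward(x: torch.Tensor) -> torch.Tensor:
    if torch.is_grad_enabled():
        # Training path: recompute activations in backward to save memory
        return checkpoint(layer, x)
    # Inference/validation path: plain forward, no checkpoint overhead
    return layer(x)

x = torch.randn(4, 16, requires_grad=True)
forward(x).sum().backward()          # checkpointed path
with torch.no_grad():
    _ = forward(torch.randn(4, 16))  # plain path
```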
@@ -188,7 +188,7 @@ class LayerNorm(nn.Module):
                 self.bias.to(dtype=d),
                 self.eps
             )
-        elif(d == torch.bfloat16):
+        else:
             out = nn.functional.layer_norm(
                 x,
                 self.c_in,
@@ -209,7 +209,7 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
-    elif(d == torch.bfloat16):
+    else:
         s = torch.nn.functional.softmax(t, dim=dim)

     return s
...
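Both primitives.py fixes address the same control-flow bug: with `elif(d == torch.bfloat16)`, an input of any other dtype matched neither branch, leaving the output variable unassigned. A small demonstration of the failure mode, with the deepspeed check replaced by a hypothetical `ds_initialized` flag:

```python
# Sketch of the bug the two `else:` fixes address. `ds_initialized` stands in
# for deepspeed.utils.is_initialized(); it is not part of the original code.
import torch

def softmax_old(t: torch.Tensor, dim: int = -1, ds_initialized: bool = True):
    d = t.dtype
    if d is torch.bfloat16 and not ds_initialized:
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    elif d == torch.bfloat16:
        s = torch.nn.functional.softmax(t, dim=dim)
    return s  # `s` was never assigned for non-bf16 dtypes

try:
    softmax_old(torch.randn(2, 8))  # float32 input
except UnboundLocalError as e:
    print("non-bf16 input fell through both branches:", e)
```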
@@ -65,8 +65,7 @@ class TriangleAttention(nn.Module):
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": x,
-            "k_x": x,
-            "v_x": x,
+            "kv_x": x,
             "biases": biases,
         }
         return chunk_layer(
...
@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
                 b.msa_att_row
             ),
             col_att_name: msa_col_att_params,
-            "msa_transition": MSATransitionParams(b.msa_transition),
-            "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
-            "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
-            "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
-            "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
-            "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
-            "pair_transition": PairTransitionParams(b.pair_transition),
+            "msa_transition": MSATransitionParams(b.core.msa_transition),
+            "outer_product_mean":
+                OuterProductMeanParams(b.core.outer_product_mean),
+            "triangle_multiplication_outgoing":
+                TriMulOutParams(b.core.tri_mul_out),
+            "triangle_multiplication_incoming":
+                TriMulInParams(b.core.tri_mul_in),
+            "triangle_attention_starting_node":
+                TriAttParams(b.core.tri_att_start),
+            "triangle_attention_ending_node":
+                TriAttParams(b.core.tri_att_end),
+            "pair_transition":
+                PairTransitionParams(b.core.pair_transition),
         }
         return d
@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         [TemplatePairBlockParams(b) for b in tps_blocks]
     )

-    ems_blocks = model.extra_msa_stack.stack.blocks
+    ems_blocks = model.extra_msa_stack.blocks
     ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

     evo_blocks = model.evoformer.blocks
...
@@ -1520,7 +1520,6 @@ class AlphaFoldLoss(nn.Module):
             weight = self.config[loss_name].weight
             if weight:
                 loss = loss_fn()
-
                 if(torch.isnan(loss) or torch.isinf(loss)):
                     logging.warning(f"{loss_name} loss is NaN. Skipping...")
                     loss = loss.new_tensor(0., requires_grad=True)
...
import argparse
from functools import partial
import logging
from multiprocessing import Pool
import os
import sys
import json

sys.path.append(".")  # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse


def parse_file(f, args):
    with open(os.path.join(args.mmcif_dir, f), "r") as fp:
        mmcif_string = fp.read()

    file_id = os.path.splitext(f)[0]
    mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
    if mmcif.mmcif_object is None:
        logging.info(f"Could not parse {f}. Skipping...")
        return {}
    else:
        mmcif = mmcif.mmcif_object

    local_data = {}
    local_data["release_date"] = mmcif.header["release_date"]
    local_data["no_chains"] = len(list(mmcif.structure.get_chains()))

    return {file_id: local_data}


def main(args):
    files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data, indent=4))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mmcif_dir", type=str, help="Directory containing mmCIF files"
    )
    parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )
    args = parser.parse_args()

    main(args)
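Assuming this new script lands at scripts/generate_mmcif_cache.py, the path the warning earlier in this commit points to, it would be invoked along these lines; the directory and output paths are placeholders:

```python
# Hypothetical invocation (paths are placeholders):
#     python scripts/generate_mmcif_cache.py /data/pdb_mmcif/ mmcif_cache.json \
#         --no_workers 16
import json

with open("mmcif_cache.json", "r") as fp:
    cache = json.load(fp)

# Each entry maps a file ID to its parsed header data, e.g.
# {"1abc": {"release_date": "1999-01-01", "no_chains": 2}, ...}
for file_id, entry in list(cache.items())[:3]:
    print(file_id, entry["release_date"], entry["no_chains"])
```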
@@ -25,11 +25,12 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
         first_name = names[0]
         alignment_dir = os.path.join(args.output_dir, first_name)

-        try:
-            os.makedirs(alignment_dir)
-        except Exception as e:
-            logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
-            continue
+        os.makedirs(alignment_dir, exist_ok=True)
+        # try:
+        #     os.makedirs(alignment_dir)
+        # except Exception as e:
+        #     logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
+        #     continue

         fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
         with os.fdopen(fd, 'w') as fp:
@@ -48,14 +49,14 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
         os.remove(fasta_path)

         for name in names[1:]:
-            if(name in dirs):
-                logging.warning(
-                    f'{name} has already been processed. Skipping...'
-                )
-                continue
+            #if(name in dirs):
+            #    logging.warning(
+            #        f'{name} has already been processed. Skipping...'
+            #    )
+            #    continue

             cp_dir = os.path.join(args.output_dir, name)
-            os.makedirs(cp_dir)
+            os.makedirs(cp_dir, exist_ok=True)

             for f in os.listdir(alignment_dir):
                 copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
@@ -136,23 +137,23 @@ def main(args):
     else:
         cache = None

-    if(cache is not None and args.filter):
-        dirs = set(os.listdir(args.output_dir))
-
-        def prot_is_done(f):
-            prot_id = os.path.splitext(f)[0]
-            if(prot_id in cache):
-                chain_ids = cache[prot_id]["chain_ids"]
-                for c in chain_ids:
-                    full_name = prot_id + "_" + c
-                    if(not full_name in dirs):
-                        return False
-            else:
-                return False
-            return True
-
-        files = [f for f in files if not prot_is_done(f)]
+    dirs = []
+    #if(cache is not None and args.filter):
+    #    dirs = set(os.listdir(args.output_dir))
+    #
+    #    def prot_is_done(f):
+    #        prot_id = os.path.splitext(f)[0]
+    #        if(prot_id in cache):
+    #            chain_ids = cache[prot_id]["chain_ids"]
+    #            for c in chain_ids:
+    #                full_name = prot_id + "_" + c
+    #                if(not full_name in dirs):
+    #                    return False
+    #        else:
+    #            return False
+    #        return True
+    #
+    #    files = [f for f in files if not prot_is_done(f)]

     def split_up_arglist(arglist):
         # Split up the survivors
...
@@ -4,19 +4,19 @@ from datetime import date

 def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument(
-        'uniref90_database_path', type=str,
+        '--uniref90_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'mgnify_database_path', type=str,
+        '--mgnify_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'pdb70_database_path', type=str,
+        '--pdb70_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'template_mmcif_dir', type=str,
+        '--template_mmcif_dir', type=str, default=None,
     )
     parser.add_argument(
-        'uniclust30_database_path', type=str,
+        '--uniclust30_database_path', type=str, default=None,
     )
     parser.add_argument(
         '--bfd_database_path', type=str, default=None,
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()

-        assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
-        assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)
+        assert(torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
+        assert(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)


 class TestExtraMSAStack(unittest.TestCase):
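The assertion fixes in these tests correct an operator-precedence bug: `torch.max(torch.abs(a - b) < eps)` reduces a boolean tensor, so it passes whenever any element is within eps, while the intended check is that the largest deviation is below eps. A quick demonstration:

```python
# Demonstration of the mis-parenthesized assertion fixed above.
import torch

eps = 1e-4
a = torch.tensor([0.0, 5.0])  # second element is wildly off
b = torch.tensor([0.0, 0.0])

buggy = torch.max(torch.abs(a - b) < eps)  # max over a bool tensor
fixed = torch.max(torch.abs(a - b)) < eps  # max abs error under eps

print(bool(buggy))  # True  -- passes despite the 5.0 discrepancy
print(bool(fixed))  # False -- correctly fails
```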
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
             transition_n,
             msa_dropout,
             pair_stack_dropout,
-            blocks_per_ckpt=None,
+            ckpt=False,
             inf=inf,
             eps=eps,
         ).eval()
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_transition(
+            model.evoformer.blocks[0].core.msa_transition(
                 torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                 mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
             )
             .cpu()
         )

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        print(out_gt)
+        print(out_repro)
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


 if __name__ == "__main__":
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
         out_repro = out_repro["sm"]["positions"][-1]
         out_repro = out_repro.squeeze(0)

+        print(torch.mean(torch.abs(out_gt - out_repro)))
         print(torch.max(torch.abs(out_gt - out_repro)))
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_att_row(
+            model.evoformer.blocks[0].msa_att_row(
                 torch.as_tensor(msa_act).cuda(),
                 z=torch.as_tensor(pair_act).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask).cuda(),
             )
-            .cpu()
-        )
+        ).cpu()

         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_att_col(
+            model.evoformer.blocks[0].msa_att_col(
                 torch.as_tensor(msa_act).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask).cuda(),
             )
-            .cpu()
-        )
+        ).cpu()

+        print(torch.mean(torch.abs(out_gt - out_repro)))
         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.extra_msa_stack.stack.blocks[0]
-            .msa_att_col(
+            model.extra_msa_stack.blocks[0].msa_att_col(
                 torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_att_start
+            model.evoformer.blocks[0].core.tri_att_start
             if starting
-            else model.evoformer.blocks[0].tri_att_end
+            else model.evoformer.blocks[0].core.tri_att_end
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
             chunk_size=None,
         ).cpu()

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_mul_in
+            model.evoformer.blocks[0].core.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].tri_mul_out
+            else model.evoformer.blocks[0].core.tri_mul_out
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
...
@@ -67,6 +67,8 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

+        self.log("loss", loss)
+
         return {"loss": loss}

     def validation_step(self, batch, batch_idx):
@@ -79,6 +81,7 @@ class OpenFoldWrapper(pl.LightningModule):
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
         loss = self.loss(outputs, batch)
+        self.log("val_loss", loss)

         return {"val_loss": loss}

     def validation_epoch_end(self, _):
@@ -316,6 +319,15 @@ if __name__ == "__main__":
         "--script_modules", type=bool_type, default=False,
         help="Whether to TorchScript eligible components of them model"
     )
+    parser.add_argument(
+        "--train_prot_data_cache_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--distillation_prot_data_cache_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--train_epoch_len", type=int, default=10000,
+    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
@@ -324,7 +336,14 @@ if __name__ == "__main__":
     )

     # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(parser, ["--accelerator", "--resume_from_checkpoint"])
+    remove_arguments(
+        parser,
+        [
+            "--accelerator",
+            "--resume_from_checkpoint",
+            "--reload_dataloaders_every_epoch"
+        ]
+    )

     args = parser.parse_args()
@@ -333,4 +352,7 @@ if __name__ == "__main__":
         (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")

+    # This re-applies the training-time filters at the beginning of every epoch
+    args.reload_dataloaders_every_epoch = True
+
     main(args)
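The forced `reload_dataloaders_every_epoch = True` ties back to the data-module changes earlier in this commit: Lightning rebuilds the train dataloader at every epoch boundary, which re-enters the path that calls `dataset.reroll()` and so re-draws the filtered epoch with fresh randomness. A minimal sketch of the same pattern outside Lightning (the training loop and arguments are illustrative):

```python
# Sketch: re-roll the filtered sample list at every epoch boundary, the
# behavior reload_dataloaders_every_epoch=True recreates inside Lightning.
# `train_dataset` is assumed to be the OpenFoldDataset defined in this diff.
import torch

def run_epochs(train_dataset, collate_fn, num_epochs: int):
    for epoch in range(num_epochs):
        train_dataset.reroll()  # draw a fresh filtered epoch of datapoints
        loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=1, collate_fn=collate_fn
        )
        for batch in loader:
            ...  # training step
```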