Commit 6ce8cfe3 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz

Fixes

parent 1df4991d
......@@ -4,7 +4,7 @@ import json
import logging
import os
import pickle
from typing import Optional, Sequence
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
......@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
):
"""
Args:
......@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
......@@ -89,26 +84,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
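# If no mapping is given, every subdirectory of alignment_dir becomes a
# sample; otherwise mapping_path is read as a plain text file containing
# one chain ID per line (replacing the old JSON id-to-name mapping)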
if(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
......@@ -126,7 +123,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -145,14 +142,25 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.mapping[str(idx)]
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
......@@ -164,11 +172,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir
path + ".core", alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb(
......@@ -176,6 +184,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
......@@ -184,6 +193,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
)
if(self._output_raw):
......@@ -196,56 +206,130 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return feats
def __len__(self):
return len(self.mapping.keys())
return len(self._chain_ids)
def train_filter(
prot_data_cache_entry: Any,
generator: torch.Generator,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = prot_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
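# Filter out low-complexity chains in which a single residue type makes
# up more than max_single_aa_prop of the sequence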
seq = prot_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
# Stochastic filters
probabilities = []
cluster_size = prot_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
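# Length-based keep probability from the AlphaFold supplement:
# max(min(N, 512), 256) / 512 for a chain of length N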
chain_length = len(prot_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
weights = [[1 - p, p] for p in probabilities]
results = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=generator,
)
def looped_sequence(sequence):
while True:
for x in sequence:
yield x
return torch.all(results)
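# A sample is kept only if every stochastic filter fires. As a worked
# example (hypothetical cache entry), a 300-residue chain in a cluster of
# size 40 survives with probability (1 / 40) * (300 / 512), about 1.5%:
#
#   entry = {"seq": "M" * 150 + "K" * 150, "resolution": 2.3, "cluster_size": 40}
#   g = torch.Generator().manual_seed(42)
#   keep = train_filter(entry, g)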
class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldDataset(torch.utils.data.Dataset):
"""
The Dataset is written to accommodate the requirement that proteins are
sampled from the distillation set with some probability p
and from the PDB set with probability (1 - p). Proteins are sampled
from both sets without replacement, and as soon as either set is
emptied, it is refilled. The Dataset therefore has an arbitrary length.
Nevertheless, for compatibility with various PyTorch Lightning
functionalities, it is possible to specify an epoch length. This length
has no effect on the output of the Dataset.
Implements the stochastic filters applied during AlphaFold's training.
Because samples are drawn from the constituent datasets at random, the
length of an OpenFoldDataset is arbitrary and is fixed by epoch_len.
Samples are selected and filtered once at initialization (and again on
each call to reroll()).
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[float],
epoch_len: int,
prot_data_cache_paths: List[str],
filter_fn: Optional[Any] = train_filter,
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.samplers = [
looped_sequence(RandomSampler(d)) for d in datasets
]
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.filter_fn = filter_fn
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
self.distr = torch.distributions.categorical.Categorical(
probs=torch.tensor(probabilities),
)
self.shuffled_idx_iters = []
for d in datasets:
self.shuffled_idx_iters.append(
looped_shuffled_dataset_idx(len(d))
)
def __iter__(self):
return self
self.prot_data_caches = []
for path in prot_data_cache_paths:
with open(path, "r") as fp:
self.prot_data_caches.append(json.load(fp))
def __next__(self):
dataset_idx = self.distr.sample()
sampler = self.samplers[dataset_idx]
element_idx = next(sampler)
return self.datasets[dataset_idx][element_idx]
if(_roll_at_init):
self.reroll()
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
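# Draw epoch_len dataset choices according to self.probabilities, then,
# for each choice, advance that dataset's shuffled index stream until an
# index passes the filter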
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
dataset = self.datasets[dataset_idx]
idx_iter = self.shuffled_idx_iters[dataset_idx]
prot_data_cache = self.prot_data_caches[dataset_idx]
datapoint_idx = None
while datapoint_idx is None:
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
if(self.filter_fn(prot_data_cache[chain_id], self.generator)):
datapoint_idx = candidate_idx
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"):
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
......@@ -283,21 +367,20 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(self.config.supervised.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
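# With uniform recycling enabled, every recycling count in
# [0, max_iters] is equally likely; otherwise the full max_iters is
# always used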
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
......@@ -362,8 +445,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_filter_fn: Optional[Any] = train_filter,
train_prot_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_prot_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
......@@ -374,6 +460,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -383,8 +471,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_filter_fn = train_filter_fn
self.train_prot_data_cache_path = train_prot_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_prot_data_cache_path = (
distillation_prot_data_cache_path
)
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
......@@ -397,6 +490,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
......@@ -406,11 +500,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None):
if(self.training_mode and train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
elif(not self.training_mode and self.predict_alingment_dir is None):
elif(not self.training_mode and predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
......@@ -420,10 +514,28 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
def setup(self, stage: Optional[str] = None):
if(stage is None):
stage = "train"
cache_missing = (
train_filter_fn and
(
train_prot_data_cache_path is None or
(
distillation_data_dir is not None and
distillation_prot_data_cache_path is None
)
)
)
if(cache_missing):
raise ValueError(
"If train_filter_fn is given, so must the protein data caches"
)
# An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
......@@ -434,10 +546,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
_alignment_index=self._alignment_index,
)
if(self.training_mode):
self.train_dataset = dataset_gen(
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
......@@ -449,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
_output_raw=True,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
......@@ -461,13 +575,30 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
d_prob = self.config.train.distillation_prob
self.train_dataset = OpenFoldDataset(
datasets=[self.train_dataset, distillation_dataset],
probabilities=[1 - d_prob, d_prob],
epoch_len=(
self.train_dataset.len() + distillation_dataset.len()
),
)
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
self.distillation_prot_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
]
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
prot_data_cache_paths=prot_data_cache_paths,
filter_fn=self.train_filter_fn,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen(
......@@ -497,6 +628,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
......
......@@ -422,42 +422,86 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
if(_alignment_index is not None):
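# The index is assumed to take the form
# {"db": <path to a single concatenated alignment file>,
#  "files": [(name, start, size), ...]},
# so each original alignment file is recovered by seeking to its offset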
fp = open(_alignment_index["db"], "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
return msa_data
def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(_alignment_index is not None):
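# Same offset-indexed layout as in _parse_msa_data: each template hit
# file is a (start, size) slice of the concatenated alignment DB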
fp = open(_alignment_index["db"], 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits
......@@ -465,8 +509,9 @@ class DataPipeline:
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir)
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
......@@ -496,6 +541,7 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
......@@ -509,7 +555,7 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -535,6 +581,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[Any] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
......@@ -552,7 +599,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -570,6 +617,7 @@ class DataPipeline:
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_alignment_index: Optional[Any] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
......@@ -586,7 +634,7 @@ class DataPipeline:
is_distillation
)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -601,6 +649,7 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
......@@ -613,7 +662,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......
......@@ -360,7 +360,8 @@ class ExtraMSABlock(nn.Module):
mask=msa_mask,
chunk_size=chunk_size,
_chunk_logits=_chunk_logits,
_checkpoint_chunks=self.ckpt,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
)
......
......@@ -188,7 +188,7 @@ class LayerNorm(nn.Module):
self.bias.to(dtype=d),
self.eps
)
elif(d == torch.bfloat16):
else:
out = nn.functional.layer_norm(
x,
self.c_in,
......@@ -209,7 +209,7 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
elif(d == torch.bfloat16):
else:
s = torch.nn.functional.softmax(t, dim=dim)
return s
......
......@@ -65,8 +65,7 @@ class TriangleAttention(nn.Module):
) -> torch.Tensor:
mha_inputs = {
"q_x": x,
"k_x": x,
"v_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
......
......@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"pair_transition": PairTransitionParams(b.pair_transition),
"msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
}
return d
......@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
[TemplatePairBlockParams(b) for b in tps_blocks]
)
ems_blocks = model.extra_msa_stack.stack.blocks
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
......
......@@ -1520,7 +1520,6 @@ class AlphaFoldLoss(nn.Module):
weight = self.config[loss_name].weight
if weight:
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
......
import argparse
from functools import partial
import logging
from multiprocessing import Pool
import os
import sys
import json
sys.path.append(".") # an innocent hack to get this to run from the top level
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
with open(os.path.join(args.mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if mmcif.mmcif_object is None:
logging.info(f"Could not parse {f}. Skipping...")
return {}
else:
mmcif = mmcif.mmcif_object
local_data = {}
local_data["release_date"] = mmcif.header["release_date"]
local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
return {file_id: local_data}
def main(args):
files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
fn = partial(parse_file, args=args)
data = {}
with Pool(processes=args.no_workers) as p:
with tqdm(total=len(files)) as pbar:
for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
data.update(d)
pbar.update()
with open(args.output_path, "w") as fp:
fp.write(json.dumps(data, indent=4))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"mmcif_dir", type=str, help="Directory containing mmCIF files"
)
parser.add_argument(
"output_path", type=str, help="Path for .json output"
)
parser.add_argument(
"--no_workers", type=int, default=4,
help="Number of workers to use for parsing"
)
parser.add_argument(
"--chunksize", type=int, default=10,
help="How many files should be distributed to each worker at a time"
)
args = parser.parse_args()
main(args)
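# Example invocation (hypothetical paths):
#
#   python scripts/generate_mmcif_cache.py /data/pdb_mmcif/ mmcif_cache.json \
#       --no_workers 16 --chunksize 10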
......@@ -25,11 +25,12 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
first_name = names[0]
alignment_dir = os.path.join(args.output_dir, first_name)
try:
os.makedirs(alignment_dir)
except Exception as e:
logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
continue
os.makedirs(alignment_dir, exist_ok=True)
# try:
# os.makedirs(alignment_dir)
# except Exception as e:
# logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
# continue
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
......@@ -48,14 +49,14 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
os.remove(fasta_path)
for name in names[1:]:
if(name in dirs):
logging.warning(
f'{name} has already been processed. Skipping...'
)
continue
#if(name in dirs):
# logging.warning(
# f'{name} has already been processed. Skipping...'
# )
# continue
cp_dir = os.path.join(args.output_dir, name)
os.makedirs(cp_dir)
os.makedirs(cp_dir, exist_ok=True)
for f in os.listdir(alignment_dir):
copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
......@@ -136,23 +137,23 @@ def main(args):
else:
cache = None
if(cache is not None and args.filter):
dirs = set(os.listdir(args.output_dir))
def prot_is_done(f):
prot_id = os.path.splitext(f)[0]
if(prot_id in cache):
chain_ids = cache[prot_id]["chain_ids"]
for c in chain_ids:
full_name = prot_id + "_" + c
if(not full_name in dirs):
return False
else:
return False
return True
files = [f for f in files if not prot_is_done(f)]
dirs = []
#if(cache is not None and args.filter):
# dirs = set(os.listdir(args.output_dir))
# def prot_is_done(f):
# prot_id = os.path.splitext(f)[0]
# if(prot_id in cache):
# chain_ids = cache[prot_id]["chain_ids"]
# for c in chain_ids:
# full_name = prot_id + "_" + c
# if(not full_name in dirs):
# return False
# else:
# return False
# return True
# files = [f for f in files if not prot_is_done(f)]
def split_up_arglist(arglist):
# Split up the survivors
......
......@@ -4,19 +4,19 @@ from datetime import date
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'uniref90_database_path', type=str,
'--uniref90_database_path', type=str, default=None,
)
parser.add_argument(
'mgnify_database_path', type=str,
'--mgnify_database_path', type=str, default=None,
)
parser.add_argument(
'pdb70_database_path', type=str,
'--pdb70_database_path', type=str, default=None,
)
parser.add_argument(
'template_mmcif_dir', type=str,
'--template_mmcif_dir', type=str, default=None,
)
parser.add_argument(
'uniclust30_database_path', type=str,
'--uniclust30_database_path', type=str, default=None,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
......
......@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
ckpt=False,
inf=inf,
eps=eps,
).eval()
......@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0]
.msa_transition(
model.evoformer.blocks[0].core.msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
print(out_gt)
print(out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__":
......
......@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)
print(torch.mean(torch.abs(out_gt - out_repro)))
print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
......@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0]
.msa_att_row(
model.evoformer.blocks[0].msa_att_row(
torch.as_tensor(msa_act).cuda(),
z=torch.as_tensor(pair_act).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
......@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0]
.msa_att_col(
model.evoformer.blocks[0].msa_att_col(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
).cpu()
print(torch.mean(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
......@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.extra_msa_stack.stack.blocks[0]
.msa_att_col(
model.extra_msa_stack.blocks[0].msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
chunk_size=4,
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
......
......@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_att_start
model.evoformer.blocks[0].core.tri_att_start
if starting
else model.evoformer.blocks[0].tri_att_end
else model.evoformer.blocks[0].core.tri_att_end
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
......@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_mul_in
model.evoformer.blocks[0].core.tri_mul_in
if incoming
else model.evoformer.blocks[0].tri_mul_out
else model.evoformer.blocks[0].core.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
......@@ -67,6 +67,8 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss = self.loss(outputs, batch)
self.log("loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
......@@ -79,6 +81,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = self.loss(outputs, batch)
self.log("val_loss", loss)
return {"val_loss": loss}
def validation_epoch_end(self, _):
......@@ -316,6 +319,15 @@ if __name__ == "__main__":
"--script_modules", type=bool_type, default=False,
help="Whether to TorchScript eligible components of them model"
)
parser.add_argument(
"--train_prot_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--distillation_prot_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
......@@ -324,7 +336,14 @@ if __name__ == "__main__":
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments(parser, ["--accelerator", "--resume_from_checkpoint"])
remove_arguments(
parser,
[
"--accelerator",
"--resume_from_checkpoint",
"--reload_dataloaders_every_epoch"
]
)
args = parser.parse_args()
......@@ -333,4 +352,7 @@ if __name__ == "__main__":
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_epoch = True
main(args)