Unverified Commit cfd0fc6e authored by Gustaf Ahdritz, committed by GitHub

Merge pull request #76 from aqlaboratory/chunking_experiment_rebased

parents c9e0f894 2726892a
......@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
......@@ -228,7 +229,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 8,
"num_workers": 16,
},
},
},
......@@ -320,10 +321,10 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
......@@ -376,7 +377,7 @@ config = mlc.ConfigDict(
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": False,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
......@@ -454,6 +455,7 @@ config = mlc.ConfigDict(
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.0,
"enabled": tm_enabled,
},
"eps": eps,
},
......
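Note on the config change above: tm_enabled is an ml_collections FieldReference, so one flag drives both heads.tm.enabled and loss.tm.enabled. A minimal sketch of that propagation (illustrative field names only, not the full OpenFold config):

import ml_collections as mlc

# One reference shared by two config fields.
tm_enabled = mlc.FieldReference(False, field_type=bool)
cfg = mlc.ConfigDict({
    "heads": {"tm": {"enabled": tm_enabled}},
    "loss": {"tm": {"enabled": tm_enabled}},
})

# Flipping the reference once updates every field that holds it.
tm_enabled.set(True)
assert cfg.heads.tm.enabled and cfg.loss.tm.enabled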
......@@ -4,7 +4,7 @@ import json
import logging
import os
import pickle
from typing import Optional, Sequence
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
......@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
):
"""
Args:
......@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
......@@ -89,26 +84,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
......@@ -126,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -145,14 +144,26 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.mapping[str(idx)]
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
......@@ -164,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir
path + ".core", alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb(
......@@ -176,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
......@@ -184,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
)
if(self._output_raw):
......@@ -196,53 +209,150 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return feats
def __len__(self):
return len(self.mapping.keys())
return len(self._chain_ids)
def looped_sequence(sequence):
while True:
for x in sequence:
yield x
def deterministic_train_filter(
prot_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = prot_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = prot_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
prot_data_cache_entry: Any,
) -> List[float]:
# Stochastic filters
probabilities = []
cluster_size = prot_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(prot_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
class OpenFoldDataset(torch.utils.data.IterableDataset):
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
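For intuition, the keep-probability returned above is just the product of the per-criterion terms: 1 / cluster_size, and the chain length clamped to [256, 512] then divided by 512. A worked example with a hypothetical cache entry (values are illustrative only):

# Hypothetical prot_data_cache entry
entry = {"seq": "ACDEFGHIKL" * 30, "cluster_size": 4}            # length 300

p_cluster = 1 / entry["cluster_size"]                            # 0.25
p_length = (1 / 512) * max(min(len(entry["seq"]), 512), 256)     # 300 / 512 ≈ 0.586
p_keep = p_cluster * p_length                                    # ≈ 0.146

assert abs(get_stochastic_train_filter_prob(entry) - p_keep) < 1e-9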
class OpenFoldDataset(torch.utils.data.Dataset):
"""
The Dataset is written to accommodate the requirement that proteins are
sampled from the distillation set with some probability p
and from the PDB set with probability (1 - p). Proteins are sampled
from both sets without replacement, and as soon as either set is
emptied, it is refilled. The Dataset therefore has an arbitrary length.
Nevertheless, for compatibility with various PyTorch Lightning
functionalities, it is possible to specify an epoch length. This length
has no effect on the output of the Dataset.
Implements the stochastic filters applied during AlphaFold's training.
Because samples are selected from constituent datasets randomly, the
length of an OpenFoldDataset is arbitrary. Samples are selected
and filtered once at initialization.
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int],
epoch_len: int,
prot_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.samplers = [
looped_sequence(RandomSampler(d)) for d in datasets
]
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.prot_data_caches = []
for path in prot_data_cache_paths:
with open(path, "r") as fp:
self.prot_data_caches.append(json.load(fp))
self.distr = torch.distributions.categorical.Categorical(
probs=torch.tensor(probabilities),
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
prot_data_cache = self.prot_data_caches[dataset_idx]
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
prot_data_cache_entry = prot_data_cache[chain_id]
if(not deterministic_train_filter(prot_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
prot_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
def __iter__(self):
return self
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
def __next__(self):
dataset_idx = self.distr.sample()
sampler = self.samplers[dataset_idx]
element_idx = next(sampler)
return self.datasets[dataset_idx][element_idx]
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
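The reroll() above is what ties the new map-style dataset to a fixed train_epoch_len: each call redraws the epoch's (dataset, index) pairs, after which __getitem__ behaves like an ordinary indexed dataset (the data module calls reroll() in its train branch further down). A hedged usage sketch; the constituent datasets and cache paths below are placeholders:

# Sketch only: train_dataset / distillation_dataset would be
# OpenFoldSingleDataset instances, and the cache paths real JSON files.
dataset = OpenFoldDataset(
    datasets=[train_dataset, distillation_dataset],
    probabilities=[0.75, 0.25],
    epoch_len=50000,
    prot_data_cache_paths=["pdb_cache.json", "distillation_cache.json"],
    _roll_at_init=False,
)

dataset.reroll()          # draw and filter this epoch's samples
feats = dataset[0]        # delegates to the chosen constituent dataset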
class OpenFoldBatchCollator:
def __init__(self, config, stage="train"):
......@@ -361,8 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_prot_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_prot_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
......@@ -373,6 +485,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -382,8 +496,12 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_prot_data_cache_path = train_prot_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_prot_data_cache_path = (
distillation_prot_data_cache_path
)
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
......@@ -396,6 +514,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
......@@ -405,11 +524,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None):
if(self.training_mode and train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
elif(not self.training_mode and self.predict_alingment_dir is None):
elif(not self.training_mode and predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
......@@ -419,10 +538,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
def setup(self, stage: Optional[str] = None):
if(stage is None):
stage = "train"
# An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
......@@ -436,7 +558,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
if(self.training_mode):
self.train_dataset = dataset_gen(
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
......@@ -446,8 +568,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=self._alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
......@@ -460,12 +584,28 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
self.distillation_prot_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
]
self.train_dataset = OpenFoldDataset(
datasets=[self.train_dataset, distillation_dataset],
probabilities=[1 - d_prob, d_prob],
epoch_len=(
self.train_dataset.len() + distillation_dataset.len()
),
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
prot_data_cache_paths=prot_data_cache_paths,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
......@@ -496,6 +636,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
......
......@@ -422,8 +422,38 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
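The _alignment_index entry consumed here is not defined in this diff; judging from how read_msa seeks into a single packed file, one chain's entry presumably looks something like the following (hypothetical file names and offsets):

# "db" is the packed alignment database inside alignment_dir;
# "files" lists (original_filename, byte_offset, byte_size) triples.
alignment_index_entry = {
    "db": "alignments.db",
    "files": [
        ["uniref90_hits.a3m", 0, 123456],
        ["mgnify_hits.sto", 123456, 78910],
        ["pdb70_hits.hhr", 202366, 4567],
    ],
}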
......@@ -448,8 +478,25 @@ class DataPipeline:
def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
......@@ -465,8 +512,9 @@ class DataPipeline:
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir)
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
......@@ -496,6 +544,7 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
......@@ -509,7 +558,7 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -522,7 +571,7 @@ class DataPipeline:
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {
**sequence_features,
......@@ -535,6 +584,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
......@@ -552,7 +602,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -560,7 +610,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
......@@ -570,6 +620,7 @@ class DataPipeline:
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
......@@ -586,14 +637,14 @@ class DataPipeline:
is_distillation
)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**pdb_feats, **template_features, **msa_features}
......@@ -601,6 +652,7 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
......@@ -613,7 +665,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......
......@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from typing import Tuple
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import one_hot
......@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module):
self.bins = None
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = nn.LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
......
......@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
from openfold.model.msa import (
MSARowAttentionWithPairBias,
......@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import chunk_layer
......@@ -60,7 +61,7 @@ class MSATransition(nn.Module):
self.c_m = c_m
self.n = n
self.layer_norm = nn.LayerNorm(self.c_m)
self.layer_norm = LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
......@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
return m
class EvoformerBlock(nn.Module):
class EvoformerBlockCore(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerBlock, self).__init__()
self._is_extra_msa_stack = _is_extra_msa_stack
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
super(EvoformerBlockCore, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
......@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
transition_n,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
......@@ -220,10 +192,6 @@ class EvoformerBlock(nn.Module):
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
......@@ -245,6 +213,175 @@ class EvoformerBlock(nn.Module):
return m, z
class EvoformerBlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
return m, z
class ExtraMSABlock(nn.Module):
"""
Almost identical to the standard EvoformerBlock, except that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
requires more fine-grained control over checkpointing. Separated from
its twin to preserve the TorchScript-ability of the latter.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
self.msa_att_row(
m.clone(),
z=z.clone(),
mask=msa_mask,
chunk_size=chunk_size,
_chunk_logits=_chunk_logits,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
)
def fn(m, z):
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
if(torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
else:
m, z = fn(m, z)
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
......@@ -271,7 +408,6 @@ class EvoformerStack(nn.Module):
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False,
**kwargs,
):
"""
......@@ -313,7 +449,6 @@ class EvoformerStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self._is_extra_msa_stack = _is_extra_msa_stack
self.blocks = nn.ModuleList()
......@@ -332,15 +467,12 @@ class EvoformerStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
)
self.blocks.append(block)
if not self._is_extra_msa_stack:
self.linear = Linear(c_m, c_s)
def forward(
self,
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -390,12 +522,7 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = None
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)
s = self.linear(m[..., 0, :, :])
return m, z, s
......@@ -405,8 +532,7 @@ class ExtraMSAStack(nn.Module):
Implements Algorithm 18.
"""
def __init__(
self,
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -419,38 +545,38 @@ class ExtraMSAStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
inf: float,
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
c_s = None
self.stack = EvoformerStack(
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
c_s=c_s,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
no_blocks=no_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
blocks_per_ckpt=blocks_per_ckpt,
inf=inf,
eps=eps,
clear_cache_between_blocks=clear_cache_between_blocks,
_is_extra_msa_stack=True,
ckpt=False,
)
self.blocks.append(block)
def forward(
self,
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
......@@ -471,12 +597,27 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
_, z, _ = self.stack(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
......@@ -16,7 +16,7 @@
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.loss import (
compute_plddt,
compute_tm,
......@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module):
self.c_in = c_in
self.c_hidden = c_hidden
self.layer_norm = nn.LayerNorm(self.c_in)
self.layer_norm = LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
......
......@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
)
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
......@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3),
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
......@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"],
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
......@@ -175,6 +175,12 @@ class AlphaFold(nn.Module):
# Primary output dictionary
outputs = {}
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
......@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
requires_grad=False,
)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
......@@ -246,15 +254,13 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_mask = feats["template_mask"]
if(torch.any(template_mask)):
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
)
......@@ -284,9 +290,9 @@ class AlphaFold(nn.Module):
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
)
......@@ -297,8 +303,8 @@ class AlphaFold(nn.Module):
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
......@@ -312,7 +318,7 @@ class AlphaFold(nn.Module):
s,
z,
feats["aatype"],
mask=feats["seq_mask"],
mask=feats["seq_mask"].to(dtype=s.dtype),
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......@@ -336,7 +342,9 @@ class AlphaFold(nn.Module):
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
self.extra_msa_stack.stack.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
......@@ -345,9 +353,9 @@ class AlphaFold(nn.Module):
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
self.extra_msa_stack.stack.blocks_per_ckpt = (
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
for b in self.extra_msa_stack.blocks:
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch):
"""
......
......@@ -16,9 +16,16 @@
import math
import torch
import torch.nn as nn
from typing import Optional, List
from openfold.model.primitives import Linear, Attention, GlobalAttention
from typing import Optional, List, Tuple
from openfold.model.primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -61,12 +68,12 @@ class MSAAttention(nn.Module):
self.c_z = c_z
self.inf = inf
self.layer_norm_m = nn.LayerNorm(self.c_in)
self.layer_norm_m = LayerNorm(self.c_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
......@@ -83,32 +90,16 @@ class MSAAttention(nn.Module):
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q_x": m, "k_x": m, "v_x": m, "biases": biases},
{"q_x": m, "kv_x": m, "biases": biases},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(self,
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
......@@ -120,16 +111,14 @@ class MSAAttention(nn.Module):
)
# [*, N_seq, 1, 1, N_res]
bias = (self.inf * (mask - 1))[..., :, None, None, :]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# This step simply returns a larger view of the bias, and does not
# consume additional memory.
# [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand(
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
biases = [bias]
#bias = bias.expand(
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
#)
if (self.pair_bias and
z is not None and # For the
......@@ -145,12 +134,91 @@ class MSAAttention(nn.Module):
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
return m, mask_bias, z
@torch.jit.ignore
def _chunked_msa_attn(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
) -> torch.Tensor:
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(torch.is_grad_enabled() and checkpoint):
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
m, q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
key=k,
value=v,
biases=[mask_bias, z],
chunk_size=chunk_logits,
chunk_dim=MSA_DIM,
checkpoint=checkpoint,
)
if(torch.is_grad_enabled() and checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
m = self.mha._wrap_up(o, m)
return m
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, chunk_size)
else:
m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)
m = self.mha(
q_x=m,
kv_x=m,
biases=biases
)
return m
......
......@@ -17,7 +17,7 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
......@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
self.c_z = c_z
self.n = n
self.layer_norm = nn.LayerNorm(self.c_z)
self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
......
......@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
......@@ -164,6 +167,135 @@ class Linear(nn.Linear):
raise ValueError("Invalid init string.")
class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
super(LayerNorm, self).__init__()
self.c_in = (c_in,)
self.eps = eps
self.weight = nn.Parameter(torch.ones(c_in))
self.bias = nn.Parameter(torch.zeros(c_in))
def forward(self, x):
d = x.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
self.c_in,
self.weight.to(dtype=d),
self.bias.to(dtype=d),
self.eps
)
else:
out = nn.functional.layer_norm(
x,
self.c_in,
self.weight,
self.bias,
self.eps,
)
return out
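The point of this custom LayerNorm is that bfloat16 activations pass through without being silently promoted to fp32 (outside DeepSpeed), with the fp32 parameters cast to the input dtype per call. A small sketch of the intended behavior, assuming a PyTorch build with bfloat16 layer_norm support:

import torch

ln = LayerNorm(64)
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
out = ln(x)

assert out.dtype == torch.bfloat16        # activations stay in bf16
assert ln.weight.dtype == torch.float32   # parameters remain fp32; cast per call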
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of
type bfloat16
"""
d = t.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
s = torch.nn.functional.softmax(t, dim=dim)
return s
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
# [*, H, Q, C_hidden]
query = permute_final_dims(query, (1, 0, 2))
# [*, H, C_hidden, K]
key = permute_final_dims(key, (1, 2, 0))
# [*, H, V, C_hidden]
value = permute_final_dims(value, (1, 0, 2))
# [*, H, Q, K]
a = torch.matmul(query, key)
for b in biases:
a += b
a = softmax(a, -1)
# [*, H, Q, C_hidden]
a = torch.matmul(a, value)
# [*, Q, H, C_hidden]
a = a.transpose(-2, -3)
return a
@torch.jit.ignore
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if(checkpoint and len(biases) > 2):
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b for b in [b1, b2] if b is not None]
return _attention(q, k, v, bs)
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = query.shape[chunk_dim]
for start in range(0, count, chunk_size):
end = start + chunk_size
idx = [slice(None)] * len(query.shape)
idx[chunk_dim] = slice(start, end)
idx_tup = tuple(idx)
q_chunk = query[idx_tup]
k_chunk = key[idx_tup]
v_chunk = value[idx_tup]
def _slice_bias(b):
idx[chunk_dim] = (
slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
)
return b[tuple(idx)]
if(checkpoint):
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
]
o_chunk = checkpoint_fn(_checkpointable_attention,
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
)
else:
bias_chunks = [
_slice_bias(b) for b in biases
]
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
return o
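Chunking here only splits the leading (MSA) dimension, so each chunk still attends over all keys and the result should match unchunked _attention exactly. A small equivalence check under that assumption, with illustrative shapes:

import torch

# [batch, N_seq, N_res, heads, c_hidden]; chunk_dim=-4 splits N_seq
q = torch.randn(2, 8, 16, 4, 32)
k = torch.randn(2, 8, 16, 4, 32)
v = torch.randn(2, 8, 16, 4, 32)
bias = torch.randn(2, 1, 4, 16, 16)   # size 1 along the chunked dim

full = _attention(q, k, v, [bias])
chunked = _attention_chunked_trainable(
    query=q, key=k, value=v, biases=[bias],
    chunk_size=3, chunk_dim=-4, checkpoint=False,
)
assert torch.allclose(full, chunked, atol=1e-6)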
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
......@@ -225,66 +357,34 @@ class Attention(nn.Module):
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(
self,
def _prep_qkv(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
k_x:
[*, K, C_k] key data
v_x:
[*, V, C_v] value data
Returns
[*, Q, C_q] attention update
"""
kv_x: torch.Tensor
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
]:
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
k = self.linear_k(kv_x)
v = self.linear_v(kv_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q, C_hidden]
q = permute_final_dims(q, (1, 0, 2))
q /= math.sqrt(self.c_hidden)
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, Q, K]
a = torch.matmul(q, k)
del q, k
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
if biases is not None:
for b in biases:
a += b
a = self.softmax(a)
# [*, H, V, C_hidden]
v = permute_final_dims(v, (1, 0, 2))
# [*, H, Q, C_hidden]
o = torch.matmul(a, v)
return q, k, v
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
def _wrap_up(self,
o: torch.Tensor,
q_x: torch.Tensor
) -> torch.Tensor:
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
......@@ -297,6 +397,56 @@ class Attention(nn.Module):
return o
def forward(
self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
kv_x:
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_lma:
Whether to use low-memory attention
q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
)
q, k, v = self._prep_qkv(q_x, kv_x)
if(use_lma):
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
o = _attention(q, k, v, biases)
o = self._wrap_up(o, q_x)
return o
class GlobalAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
......@@ -322,7 +472,6 @@ class GlobalAttention(nn.Module):
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
......@@ -348,7 +497,7 @@ class GlobalAttention(nn.Module):
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = self.softmax(a)
a = softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
......@@ -374,14 +523,13 @@ class GlobalAttention(nn.Module):
return m
@torch.jit.script
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-3], k.shape[-3]
......@@ -389,7 +537,7 @@ def _lma(
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
big_bias_chunks = [
large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
......@@ -400,11 +548,11 @@ def _lma(
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk
"...qhd,...khd->...hqk", q_chunk, k_chunk,
)
for b in small_bias_chunks:
......@@ -412,11 +560,11 @@ def _lma(
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.squeeze(-1))
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)
......@@ -437,111 +585,3 @@ def _lma(
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
return o
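The accumulation in _lma relies on the standard streaming-softmax identity: chunk-local results computed with a local max can be recombined exactly using the global max. A self-contained illustration of that identity (independent of the functions above):

import torch

scores = torch.randn(6, 10)        # attention logits for 6 queries, 10 keys
values = torch.randn(10, 3)

maxes, weights, accum = [], [], []
for s, v in zip(scores.split(4, dim=-1), values.split(4, dim=0)):
    m = s.max(dim=-1, keepdim=True).values        # chunk-local max
    e = torch.exp(s - m)
    maxes.append(m)
    weights.append(e.sum(dim=-1, keepdim=True))
    accum.append(e @ v)

global_max = torch.cat(maxes, dim=-1).max(dim=-1, keepdim=True).values
scales = [torch.exp(m - global_max) for m in maxes]
num = sum(sc * a for sc, a in zip(scales, accum))
den = sum(sc * w for sc, w in zip(scales, weights))

reference = torch.softmax(scores, dim=-1) @ values
assert torch.allclose(num / den, reference, atol=1e-6)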
class LowMemoryAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors. Implements Rabe and Staats'
low-memory self-attention algorithm.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
chunk_size:
Trades memory for better parallelization. A low value
corresponds to lower memory usage.
"""
super().__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_chunk_size: int,
kv_chunk_size: int,
biases: Optional[List[torch.Tensor]] = None,
):
if(biases is None):
biases = []
else:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(q.shape[-1])
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
......@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
......@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a * math.sqrt(1.0 / (3 * self.c_hidden))
a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
......@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
# Compute output
################
# [*, N_res, H, C_hidden]
o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3)
o = torch.matmul(
a, v.transpose(-2, -3).to(dtype=a.dtype)
).transpose(-2, -3)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
......@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z)
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
......@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
)
).to(dtype=z.dtype)
)
return s
......@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
self.layers.append(l)
self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = nn.LayerNorm(self.c)
self.layer_norm = LayerNorm(self.c)
def forward(self, s):
for l in self.layers:
......@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
self.atom_mask = None
self.lit_positions = None
self.layer_norm_s = nn.LayerNorm(self.c_s)
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_in = Linear(self.c_s, self.c_s)
......@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
)
self.ipa_dropout = nn.Dropout(self.dropout_rate)
self.layer_norm_ipa = nn.LayerNorm(self.c_s)
self.layer_norm_ipa = LayerNorm(self.c_s)
self.transition = StructureModuleTransition(
self.c_s,
......
......@@ -19,7 +19,7 @@ from typing import Optional, List
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
......@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"k_x": t,
"v_x": t,
"kv_x": t,
"biases": biases,
}
return chunk_layer(
......@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
......@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
)
self.blocks.append(block)
self.layer_norm = nn.LayerNorm(c_t)
self.layer_norm = LayerNorm(c_t)
def forward(
self,
......
......@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
from functools import partialmethod, partial
import math
from typing import Optional, List
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
self.starting = starting
self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in)
self.layer_norm = LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
......@@ -65,12 +65,11 @@ class TriangleAttention(nn.Module):
) -> torch.Tensor:
mha_inputs = {
"q_x": x,
"k_x": x,
"v_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
self.mha,
partial(self.mha),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
......@@ -116,7 +115,7 @@ class TriangleAttention(nn.Module):
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size)
else:
x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
x = self.mha(q_x=x, kv_x=x, biases=biases)
if not self.starting:
x = x.transpose(-2, -3)
......
......@@ -19,7 +19,7 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims
......@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = nn.LayerNorm(self.c_z)
self.layer_norm_out = nn.LayerNorm(self.c_hidden)
self.layer_norm_in = LayerNorm(self.c_z)
self.layer_norm_out = LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid()
......
......@@ -15,17 +15,27 @@
import deepspeed
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable
from typing import Any, Tuple, List, Callable, Optional
BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
return checkpoint
@torch.jit.ignore
def checkpoint_blocks(
blocks: List[Callable],
args: BLOCK_ARGS,
blocks_per_ckpt: int,
blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
"""
Chunk a list of blocks and run each chunk with activation
......@@ -68,10 +78,7 @@ def checkpoint_blocks(
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
if(deepspeed.checkpointing.is_configured()):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
checkpoint = get_checkpoint_fn()
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
......
......@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"pair_transition": PairTransitionParams(b.pair_transition),
"msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
}
return d
......@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
[TemplatePairBlockParams(b) for b in tps_blocks]
)
ems_blocks = model.extra_msa_stack.stack.blocks
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
......
......@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels):
log_p = torch.nn.functional.logsigmoid(logits)
log_not_p = torch.nn.functional.logsigmoid(-logits)
log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p
return loss
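The two log-sigmoid formulations that appear above are mathematically equivalent (log sigma(x) = -softplus(-x)) but behave differently at extreme logits: torch.nn.functional.logsigmoid is computed in the stable softplus form, while torch.log(torch.sigmoid(x)) underflows to -inf for very negative x. A quick illustrative check, not part of the codebase:
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, 0.0, 200.0])
print(torch.log(torch.sigmoid(x)))  # roughly [-inf, -0.6931, 0.]: sigmoid underflows to 0
print(F.logsigmoid(x))              # roughly [-200., -0.6931, 0.]: numerically stable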
......@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
return pred_lddt_ca * 100
def lddt_loss(
logits: torch.Tensor,
def lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
per_residue: bool = True,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
dmat_true = torch.sqrt(
eps
+ torch.sum(
......@@ -389,8 +378,63 @@ def lddt_loss(
)
score = score * 0.25
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
dims = (-1,) if per_residue else (-2, -1)
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
def lddt_ca(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps,
per_residue=per_residue,
)
def lddt_loss(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps
)
score = score.detach()
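Taken together, the refactor above separates the metric from the loss: lddt scores arbitrary atom sets, per residue or per structure; lddt_ca restricts it to Calpha atoms; and lddt_loss keeps its original signature while delegating the score computation. A hedged usage sketch with made-up shapes; the import path openfold.utils.loss is assumed, not shown in this diff.
import torch
from openfold.utils.loss import lddt_ca  # module path assumed

pred_pos = torch.randn(2, 64, 37, 3)   # [batch, n_res, atom37, xyz]
true_pos = torch.randn(2, 64, 37, 3)
atom_mask = torch.ones(2, 64, 37)

per_res = lddt_ca(pred_pos, true_pos, atom_mask)                        # shape [2, 64]
per_struct = lddt_ca(pred_pos, true_pos, atom_mask, per_residue=False)  # shape [2]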
......@@ -1462,7 +1506,7 @@ class AlphaFoldLoss(nn.Module):
self.config = config
def forward(self, out, batch):
if "violation" not in out.keys() and self.config.violation.weight:
if "violation" not in out.keys():
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
......@@ -1509,23 +1553,27 @@ class AlphaFoldLoss(nn.Module):
out["violation"],
**batch,
),
"tm": lambda: tm_loss(
}
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
),
}
)
cum_loss = 0.
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
if weight:
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]
cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
# Scale the loss by the square root of the minimum of the crop size and
# the (average) sequence length. See subsection 1.9.
seq_len = torch.mean(batch["seq_length"].float())
......
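The weighting loop above sums only the enabled terms (the TM loss is now gated on config.tm.enabled), replaces any term that comes back NaN or inf with a zero tensor, and finally rescales by the square root of the smaller of the average sequence length and the crop length, per subsection 1.9 of the AlphaFold supplement. A toy numeric sketch of that final scaling, with hypothetical values:
import torch

losses = {"fape": torch.tensor(1.2), "plddt_loss": torch.tensor(0.4)}
weights = {"fape": 1.0, "plddt_loss": 0.01}
cum_loss = sum(weights[k] * v for k, v in losses.items())           # 1.204
seq_len, crop_len = torch.tensor(380.0), torch.tensor(256.0)
cum_loss = cum_loss * torch.sqrt(torch.minimum(seq_len, crop_len))  # 1.204 * 16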
......@@ -26,7 +26,7 @@ def rot_matmul(
) -> torch.Tensor:
"""
Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid transfer to low-precision tensor cores.
out by hand to avoid AMP downcasting.
Args:
a: [*, 3, 3] left multiplicand
......@@ -86,7 +86,7 @@ def rot_vec_mul(
) -> torch.Tensor:
"""
Applies a rotation to a vector. Written out by hand to avoid transfer
to low-precision tensor cores.
to avoid AMP downcasting.
Args:
r: [*, 3, 3] rotation matrices
......@@ -323,6 +323,12 @@ class Rotation:
"Incorrectly shaped rotation matrix or quaternion"
)
# Force full-precision
if(quats is not None):
quats = quats.to(dtype=torch.float32)
if(rot_mats is not None):
rot_mats = rot_mats.to(dtype=torch.float32)
if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
......@@ -857,6 +863,9 @@ class Rigid:
(rots.device != trans.device)):
raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
......
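The casts above pin quaternions, rotation matrices, and translations to float32 so the geometry is not silently degraded when the surrounding model runs under mixed precision; the hand-written rot_matmul/rot_vec_mul serve the same purpose, since elementwise ops are not downcast by AMP while matmuls are. A small illustration of that difference (illustrative, not repo code):
import torch

r = torch.randn(3, 3)   # stand-in rotation matrix
v = torch.randn(3)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    via_matmul = r @ v          # autocast runs the matmul in bfloat16
    by_hand = (r * v).sum(-1)   # elementwise mul/sum stay in float32
print(via_matmul.dtype, by_hand.dtype)  # torch.bfloat16 torch.float32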
import argparse
from functools import partial
import logging
from multiprocessing import Pool
import os
import sys
import json
sys.path.append(".") # an innocent hack to get this to run from the top level
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
with open(os.path.join(args.mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if mmcif.mmcif_object is None:
logging.info(f"Could not parse {f}. Skipping...")
return {}
else:
mmcif = mmcif.mmcif_object
local_data = {}
local_data["release_date"] = mmcif.header["release_date"]
local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
return {file_id: local_data}
def main(args):
files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
fn = partial(parse_file, args=args)
data = {}
with Pool(processes=args.no_workers) as p:
with tqdm(total=len(files)) as pbar:
for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
data.update(d)
pbar.update()
with open(args.output_path, "w") as fp:
fp.write(json.dumps(data, indent=4))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"mmcif_dir", type=str, help="Directory containing mmCIF files"
)
parser.add_argument(
"output_path", type=str, help="Path for .json output"
)
parser.add_argument(
"--no_workers", type=int, default=4,
help="Number of workers to use for parsing"
)
parser.add_argument(
"--chunksize", type=int, default=10,
help="How many files should be distributed to each worker at a time"
)
args = parser.parse_args()
main(args)
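Assuming this file lives at scripts/generate_mmcif_cache.py (the path is inferred, not shown in this hunk), a typical invocation and the shape of its output would be roughly the following; all paths and values are placeholders.
# python3 scripts/generate_mmcif_cache.py data/pdb_mmcif/mmcif_files/ mmcif_cache.json --no_workers 16 --chunksize 10
#
# mmcif_cache.json then maps each file ID to its metadata, e.g.
# {"1abc": {"release_date": "2001-04-25", "no_chains": 2}, ...}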
import argparse
from functools import partial
import json
import logging
import os
import threading
from multiprocessing import cpu_count
from shutil import copyfile
import tempfile
import openfold.data.mmcif_parsing as mmcif_parsing
......@@ -10,30 +15,58 @@ from openfold.np import protein, residue_constants
from utils import add_data_args
#python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ data/uniref90/uniref90.fasta data/mgnify/mgy_clusters_2018_12.fa data/pdb70/pdb70 data/pdb_mmcif/mmcif_files/ data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt --cpus 16 --jackhmmer_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/jackhmmer --hhblits_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhblits --hhsearch_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhsearch --kalign_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/kalign
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.WARNING)
def main(args):
# Build the alignment tool runner
alignment_runner = AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus,
def run_seq_group_alignments(seq_groups, alignment_runner, args):
dirs = set(os.listdir(args.output_dir))
for seq, names in seq_groups:
first_name = names[0]
alignment_dir = os.path.join(args.output_dir, first_name)
os.makedirs(alignment_dir, exist_ok=True)
# try:
# os.makedirs(alignment_dir)
# except Exception as e:
# logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
# continue
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
fp.write(f'>query\n{seq}')
try:
alignment_runner.run(
fasta_path, alignment_dir
)
except:
logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
os.remove(fasta_path)
os.rmdir(alignment_dir)
continue
os.remove(fasta_path)
for f in os.listdir(args.input_dir):
for name in names[1:]:
#if(name in dirs):
# logging.warning(
# f'{name} has already been processed. Skipping...'
# )
# continue
cp_dir = os.path.join(args.output_dir, name)
os.makedirs(cp_dir, exist_ok=True)
for f in os.listdir(alignment_dir):
copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
def parse_and_align(files, alignment_runner, args):
for f in files:
path = os.path.join(args.input_dir, f)
file_id = os.path.splitext(f)[0]
seqs = {}
seq_group_dict = {}
if(f.endswith('.cif')):
with open(path, 'r') as fp:
mmcif_str = fp.read()
......@@ -47,9 +80,10 @@ def main(args):
else:
continue
mmcif = mmcif.mmcif_object
for k,v in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k])
seqs[chain_id] = v
for chain_letter, seq in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, chain_letter])
l = seq_group_dict.setdefault(seq, [])
l.append(chain_id)
elif(f.endswith('.fasta') or f.endswith('.fa')):
with open(path, 'r') as fp:
fasta_str = fp.read()
......@@ -61,7 +95,7 @@ def main(args):
else:
logging.warning(msg)
input_sequence = input_seqs[0]
seqs[file_id] = input_sequence
seq_group_dict[input_sequence] = [file_id]
elif(f.endswith('.core')):
with open(path, 'r') as fp:
core_str = fp.read()
......@@ -71,27 +105,114 @@ def main(args):
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
seqs[file_id] = seq
seq_group_dict[seq] = [file_id]
else:
continue
for name, seq in seqs.items():
alignment_dir = os.path.join(args.output_dir, name)
if(os.path.isdir(alignment_dir)):
logging.info(f'{f} has already been processed. Skipping...')
continue
seq_group_tuples = [(k,v) for k,v in seq_group_dict.items()]
run_seq_group_alignments(seq_group_tuples, alignment_runner, args)
os.makedirs(alignment_dir)
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
fp.write(f'>query\n{seq}')
def main(args):
# Build the alignment tool runner
alignment_runner = AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus_per_task,
)
alignment_runner.run(
fasta_path, alignment_dir
files = list(os.listdir(args.input_dir))
# Do some filtering
if(args.mmcif_cache is not None):
with open(args.mmcif_cache, "r") as fp:
cache = json.load(fp)
else:
cache = None
dirs = []
if(cache is not None and args.filter):
dirs = set(os.listdir(args.output_dir))
def prot_is_done(f):
prot_id = os.path.splitext(f)[0]
if(prot_id in cache):
chain_ids = cache[prot_id]["chain_ids"]
for c in chain_ids:
full_name = prot_id + "_" + c
if(not full_name in dirs):
return False
else:
return False
return True
files = [f for f in files if not prot_is_done(f)]
def split_up_arglist(arglist):
# Split up the survivors
if(os.environ.get("SLURM_JOB_NUM_NODES", 0)):
num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
if(num_nodes > 1):
node_id = int(os.environ["SLURM_NODEID"])
logging.warning(f"Num nodes: {num_nodes}")
logging.warning(f"Node ID: {node_id}")
arglist = arglist[node_id::num_nodes]
t_arglist = []
for i in range(args.no_tasks):
t_arglist.append(arglist[i::args.no_tasks])
return t_arglist
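    # Worked example of the striding above (hypothetical numbers): with 10 inputs,
    # SLURM_JOB_NUM_NODES=2, SLURM_NODEID=0, and --no_tasks 2, the node-level slice
    # arglist[0::2] keeps items [0, 2, 4, 6, 8], which are then striped across the
    # two local threads as [0, 4, 8] and [2, 6]. Each thread therefore receives a
    # disjoint, roughly equal share of the surviving files or sequence groups.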
if(cache is not None and "seqs" in next(iter(cache.values()))):
seq_group_dict = {}
for f in files:
prot_id = os.path.splitext(f)[0]
if(prot_id in cache):
prot_cache = cache[prot_id]
chains_seqs = zip(
prot_cache["chain_ids"], prot_cache["seqs"]
)
for chain, seq in chains_seqs:
chain_name = prot_id + "_" + chain
if(chain_name not in dirs):
l = seq_group_dict.setdefault(seq, [])
l.append(chain_name)
func = partial(run_seq_group_alignments,
alignment_runner=alignment_runner,
args=args
)
os.remove(fasta_path)
seq_groups = [(k,v) for k,v in seq_group_dict.items()]
# Sort them by group length so the tasks are approximately balanced
seq_groups = sorted(seq_groups, key=lambda x: len(x[1]))
task_arglist = [[a] for a in split_up_arglist(seq_groups)]
else:
func = partial(parse_and_align,
alignment_runner=alignment_runner,
args=args,
)
task_arglist = [[a] for a in split_up_arglist(files)]
threads = []
for i, task_args in enumerate(task_arglist):
print(f"Started thread {i}...")
t = threading.Thread(target=func, args=task_args)
threads.append(t)
t.start()
for t in threads:
t.join()
if __name__ == "__main__":
......@@ -111,9 +232,19 @@ if __name__ == "__main__":
help="Whether to crash on parsing errors"
)
parser.add_argument(
"--cpus", type=int, default=4,
"--cpus_per_task", type=int, default=cpu_count(),
help="Number of CPUs to use"
)
parser.add_argument(
"--mmcif_cache", type=str, default=None,
help="Path to mmCIF cache. Used to filter files to be parsed"
)
parser.add_argument(
"--no_tasks", type=int, default=1,
)
parser.add_argument(
"--filter", type=bool, default=True,
)
args = parser.parse_args()
......
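With the new flags above, a SLURM-friendly invocation might look like the line below; the positional and database arguments added by add_data_args are elided, and all paths are placeholders. Note also that argparse's type=bool treats any non-empty string as True, so passing --filter False still enables filtering.
# python3 scripts/precompute_alignments.py <input_dir> <output_dir> ... --cpus_per_task 16 --no_tasks 4 --mmcif_cache mmcif_cache.json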