"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "0d524bc156d77f27eeda02ca604cd6d2aff17e12"
Unverified Commit cfd0fc6e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #76 from aqlaboratory/chunking_experiment_rebased

parents c9e0f894 2726892a
...@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int) ...@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int) blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int) chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int) aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float) eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool) templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
...@@ -228,7 +229,7 @@ config = mlc.ConfigDict( ...@@ -228,7 +229,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False, "use_small_bfd": False,
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 8, "num_workers": 16,
}, },
}, },
}, },
...@@ -320,10 +321,10 @@ config = mlc.ConfigDict( ...@@ -320,10 +321,10 @@ config = mlc.ConfigDict(
"transition_n": 4, "transition_n": 4,
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": True, "clear_cache_between_blocks": True,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
}, },
"enabled": True, "enabled": True,
}, },
...@@ -376,7 +377,7 @@ config = mlc.ConfigDict( ...@@ -376,7 +377,7 @@ config = mlc.ConfigDict(
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": False, "enabled": tm_enabled,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
...@@ -454,6 +455,7 @@ config = mlc.ConfigDict( ...@@ -454,6 +455,7 @@ config = mlc.ConfigDict(
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 0.0, "weight": 0.0,
"enabled": tm_enabled,
}, },
"eps": eps, "eps": eps,
}, },
......
...@@ -4,7 +4,7 @@ import json ...@@ -4,7 +4,7 @@ import json
import logging import logging
import os import os
import pickle import pickle
from typing import Optional, Sequence from typing import Optional, Sequence, List, Any
import ml_collections as mlc import ml_collections as mlc
import numpy as np import numpy as np
...@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
_output_raw: bool = False, _output_raw: bool = False,
_alignment_index: Optional[Any] = None
): ):
""" """
Args: Args:
...@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
A dataset config object. See openfold.config A dataset config object. See openfold.config
kalign_binary_path: kalign_binary_path:
Path to kalign binary. Path to kalign binary.
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement.
max_template_hits: max_template_hits:
An upper bound on how many templates are considered. During An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled training, the templates ultimately used are subsampled
...@@ -89,26 +84,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -89,26 +84,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self._output_raw = _output_raw self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"] valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes): if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}') raise ValueError(f'mode must be one of {valid_modes}')
if(mapping_path is None):
self.mapping = {
str(i):os.path.splitext(name)[0]
for i, name in enumerate(os.listdir(alignment_dir))
}
else:
with open(mapping_path, 'r') as fp:
self.mapping = json.load(fp)
if(template_release_dates_cache_path is None): if(template_release_dates_cache_path is None):
logging.warning( logging.warning(
"Template release dates cache does not exist. Remember to run " "Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir, mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date, max_template_date=max_template_date,
...@@ -126,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -126,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw): if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -145,14 +144,26 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -145,14 +144,26 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index
) )
return data return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx): def __getitem__(self, idx):
name = self.mapping[str(idx)] name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1) spl = name.rsplit('_', 1)
if(len(spl) == 2): if(len(spl) == 2):
...@@ -164,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -164,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, file_id) path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")): if(os.path.exists(path + ".cif")):
data = self._parse_mmcif( data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
) )
elif(os.path.exists(path + ".core")): elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path + ".core", alignment_dir path + ".core", alignment_dir, _alignment_index,
) )
elif(os.path.exists(path + ".pdb")): elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
...@@ -176,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -176,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index,
) )
else: else:
raise ValueError("Invalid file type") raise ValueError("Invalid file type")
...@@ -184,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -184,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
) )
if(self._output_raw): if(self._output_raw):
...@@ -196,53 +209,150 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -196,53 +209,150 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return feats return feats
def __len__(self): def __len__(self):
return len(self.mapping.keys()) return len(self._chain_ids)
def deterministic_train_filter(
prot_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = prot_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = prot_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
prot_data_cache_entry: Any,
) -> List[float]:
# Stochastic filters
probabilities = []
cluster_size = prot_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(prot_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
def looped_sequence(sequence): return out
while True:
for x in sequence:
yield x
class OpenFoldDataset(torch.utils.data.IterableDataset): class OpenFoldDataset(torch.utils.data.Dataset):
""" """
The Dataset is written to accommodate the requirement that proteins are Implements the stochastic filters applied during AlphaFold's training.
sampled from the distillation set with some probability p Because samples are selected from constituent datasets randomly, the
and from the PDB set with probability (1 - p). Proteins are sampled length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
from both sets without replacement, and as soon as either set is and filtered once at initialization.
emptied, it is refilled. The Dataset therefore has an arbitrary length.
Nevertheless, for compatibility with various PyTorch Lightning
functionalities, it is possible to specify an epoch length. This length
has no effect on the output of the Dataset.
""" """
def __init__(self, def __init__(self,
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[int],
epoch_len: int, epoch_len: int,
prot_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
): ):
self.datasets = datasets self.datasets = datasets
self.samplers = [ self.probabilities = probabilities
looped_sequence(RandomSampler(d)) for d in datasets
]
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator
self.prot_data_caches = []
for path in prot_data_cache_paths:
with open(path, "r") as fp:
self.prot_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
prot_data_cache = self.prot_data_caches[dataset_idx]
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
prot_data_cache_entry = prot_data_cache[chain_id]
if(not deterministic_train_filter(prot_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
prot_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
self.distr = torch.distributions.categorical.Categorical( cache = [i for i, s in zip(idx, samples) if s]
probs=torch.tensor(probabilities),
)
def __iter__(self): for datapoint_idx in cache:
return self yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
def __next__(self): if(_roll_at_init):
dataset_idx = self.distr.sample() self.reroll()
sampler = self.samplers[dataset_idx]
element_idx = next(sampler) def __getitem__(self, idx):
return self.datasets[dataset_idx][element_idx] dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self): def __len__(self):
return self.epoch_len return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, stage="train"): def __init__(self, config, stage="train"):
...@@ -283,7 +393,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -283,7 +393,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs.append( keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob]) ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
) )
if(stage_cfg.uniform_recycling): if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1) 1. / (max_iters + 1) for _ in range(max_iters + 1)
...@@ -293,7 +403,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -293,7 +403,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
0. for _ in range(max_iters + 1) 0. for _ in range(max_iters + 1)
] ]
recycling_probs[-1] = 1. recycling_probs[-1] = 1.
keyed_probs.append( keyed_probs.append(
("no_recycling_iters", recycling_probs) ("no_recycling_iters", recycling_probs)
) )
...@@ -361,8 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -361,8 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date: str, max_template_date: str,
train_data_dir: Optional[str] = None, train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None, train_alignment_dir: Optional[str] = None,
train_prot_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None, distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None, distillation_alignment_dir: Optional[str] = None,
distillation_prot_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None, val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None, val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
...@@ -373,6 +485,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -373,6 +485,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -382,8 +496,12 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -382,8 +496,12 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.max_template_date = max_template_date self.max_template_date = max_template_date
self.train_data_dir = train_data_dir self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir self.train_alignment_dir = train_alignment_dir
self.train_prot_data_cache_path = train_prot_data_cache_path
self.distillation_data_dir = distillation_data_dir self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_prot_data_cache_path = (
distillation_prot_data_cache_path
)
self.val_data_dir = val_data_dir self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir self.predict_data_dir = predict_data_dir
...@@ -396,6 +514,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -396,6 +514,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None): if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError( raise ValueError(
...@@ -405,11 +524,11 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -405,11 +524,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.training_mode = self.train_data_dir is not None self.training_mode = self.train_data_dir is not None
if(self.training_mode and self.train_alignment_dir is None): if(self.training_mode and train_alignment_dir is None):
raise ValueError( raise ValueError(
'In training mode, train_alignment_dir must be specified' 'In training mode, train_alignment_dir must be specified'
) )
elif(not self.training_mode and self.predict_alingment_dir is None): elif(not self.training_mode and predict_alignment_dir is None):
raise ValueError( raise ValueError(
'In inference mode, predict_alignment_dir must be specified' 'In inference mode, predict_alignment_dir must be specified'
) )
...@@ -419,10 +538,13 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -419,10 +538,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well' 'be specified as well'
) )
def setup(self, stage: Optional[str] = None): # An ad-hoc measure for our particular filesystem restrictions
if(stage is None): self._alignment_index = None
stage = "train" if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
...@@ -435,8 +557,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -435,8 +557,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.obsolete_pdbs_file_path, self.obsolete_pdbs_file_path,
) )
if(self.training_mode): if(self.training_mode):
self.train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path, mapping_path=self.train_mapping_path,
...@@ -446,8 +568,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -446,8 +568,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
_output_raw=True, _output_raw=True,
_alignment_index=self._alignment_index,
) )
distillation_dataset = None
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
...@@ -460,13 +584,29 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -460,13 +584,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
self.train_dataset = OpenFoldDataset(
datasets=[self.train_dataset, distillation_dataset], if(distillation_dataset is not None):
probabilities=[1 - d_prob, d_prob], datasets = [train_dataset, distillation_dataset]
epoch_len=( d_prob = self.config.train.distillation_prob
self.train_dataset.len() + distillation_dataset.len() probabilities = [1 - d_prob, d_prob]
), prot_data_cache_paths = [
) self.train_prot_data_cache_path,
self.distillation_prot_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
]
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
prot_data_cache_paths=prot_data_cache_paths,
_roll_at_init=False,
)
if(self.val_data_dir is not None): if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
...@@ -496,6 +636,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -496,6 +636,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None dataset = None
if(stage == "train"): if(stage == "train"):
dataset = self.train_dataset dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"): elif(stage == "eval"):
dataset = self.eval_dataset dataset = self.eval_dataset
elif(stage == "predict"): elif(stage == "predict"):
......
...@@ -422,42 +422,89 @@ class DataPipeline: ...@@ -422,42 +422,89 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = {} msa_data = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f) if(_alignment_index is not None):
ext = os.path.splitext(f)[-1] fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
if(ext == ".a3m"): def read_msa(start, size):
with open(path, "r") as fp: fp.seek(start)
msa, deletion_matrix = parsers.parse_a3m(fp.read()) msa = fp.read(size).decode("utf-8")
data = {"msa": msa, "deletion_matrix": deletion_matrix} return msa
elif(ext == ".sto"):
with open(path, "r") as fp: for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm( msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read() read_msa(start, size)
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix} data = {"msa": msa, "deletion_matrix": deletion_matrix}
else: else:
continue continue
msa_data[name] = data
msa_data[f] = data fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
return msa_data return msa_data
def _parse_template_hits( def _parse_template_hits(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
all_hits = {} all_hits = {}
for f in os.listdir(alignment_dir): if(_alignment_index is not None):
path = os.path.join(alignment_dir, f) fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"): def read_template(start, size):
with open(path, "r") as fp: fp.seek(start)
hits = parsers.parse_hhr(fp.read()) return fp.read(size).decode("utf-8")
all_hits[f] = hits
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits return all_hits
...@@ -465,8 +512,9 @@ class DataPipeline: ...@@ -465,8 +512,9 @@ class DataPipeline:
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir) msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0): if(len(msa_data) == 0):
if(input_sequence is None): if(input_sequence is None):
...@@ -496,6 +544,7 @@ class DataPipeline: ...@@ -496,6 +544,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
...@@ -509,7 +558,7 @@ class DataPipeline: ...@@ -509,7 +558,7 @@ class DataPipeline:
input_description = input_descs[0] input_description = input_descs[0]
num_res = len(input_sequence) num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -522,7 +571,7 @@ class DataPipeline: ...@@ -522,7 +571,7 @@ class DataPipeline:
num_res=num_res, num_res=num_res,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence) msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return { return {
**sequence_features, **sequence_features,
...@@ -535,6 +584,7 @@ class DataPipeline: ...@@ -535,6 +584,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
...@@ -552,7 +602,7 @@ class DataPipeline: ...@@ -552,7 +602,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -560,7 +610,7 @@ class DataPipeline: ...@@ -560,7 +610,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"]) query_release_date=to_date(mmcif.header["release_date"])
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence) msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**mmcif_feats, **template_features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
...@@ -570,6 +620,7 @@ class DataPipeline: ...@@ -570,6 +620,7 @@ class DataPipeline:
alignment_dir: str, alignment_dir: str,
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a PDB file. Assembles features for a protein in a PDB file.
...@@ -586,14 +637,14 @@ class DataPipeline: ...@@ -586,14 +637,14 @@ class DataPipeline:
is_distillation is_distillation
) )
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
self.template_featurizer, self.template_featurizer,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence) msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**pdb_feats, **template_features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
...@@ -601,6 +652,7 @@ class DataPipeline: ...@@ -601,6 +652,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a ProteinNet .core file. Assembles features for a protein in a ProteinNet .core file.
...@@ -613,7 +665,7 @@ class DataPipeline: ...@@ -613,7 +665,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper() description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description) core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir) hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
......
...@@ -17,7 +17,7 @@ import torch ...@@ -17,7 +17,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple from typing import Tuple
from openfold.model.primitives import Linear from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import one_hot from openfold.utils.tensor_utils import one_hot
...@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module): ...@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module):
self.bins = None self.bins = None
self.linear = Linear(self.no_bins, self.c_z) self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = nn.LayerNorm(self.c_m) self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = LayerNorm(self.c_z)
def forward( def forward(
self, self,
......
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, Optional from typing import Tuple, Optional
from functools import partial from functools import partial
from openfold.model.primitives import Linear from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
from openfold.model.msa import ( from openfold.model.msa import (
MSARowAttentionWithPairBias, MSARowAttentionWithPairBias,
...@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import ( ...@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.utils.checkpointing import checkpoint_blocks from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.tensor_utils import chunk_layer
...@@ -60,7 +61,7 @@ class MSATransition(nn.Module): ...@@ -60,7 +61,7 @@ class MSATransition(nn.Module):
self.c_m = c_m self.c_m = c_m
self.n = n self.n = n
self.layer_norm = nn.LayerNorm(self.c_m) self.layer_norm = LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
...@@ -117,51 +118,23 @@ class MSATransition(nn.Module): ...@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
return m return m
class EvoformerBlock(nn.Module): class EvoformerBlockCore(nn.Module):
def __init__( def __init__(
self, self,
c_m: int, c_m: int,
c_z: int, c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int, c_hidden_opm: int,
c_hidden_mul: int, c_hidden_mul: int,
c_hidden_pair_att: int, c_hidden_pair_att: int,
no_heads_msa: int, no_heads_msa: int,
no_heads_pair: int, no_heads_pair: int,
transition_n: int, transition_n: int,
msa_dropout: float,
pair_dropout: float, pair_dropout: float,
inf: float, inf: float,
eps: float, eps: float,
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
): ):
super(EvoformerBlock, self).__init__() super(EvoformerBlockCore, self).__init__()
self._is_extra_msa_stack = _is_extra_msa_stack
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_transition = MSATransition( self.msa_transition = MSATransition(
c_m=c_m, c_m=c_m,
...@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module): ...@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
transition_n, transition_n,
) )
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
...@@ -213,17 +185,13 @@ class EvoformerBlock(nn.Module): ...@@ -213,17 +185,13 @@ class EvoformerBlock(nn.Module):
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans # DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of # should be disabled to better approximate the exact activations of
# the original. # the original.
msa_trans_mask = msa_mask if _mask_trans else None msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m = m + self.msa_transition( m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size m, mask=msa_trans_mask, chunk_size=chunk_size
) )
...@@ -245,6 +213,175 @@ class EvoformerBlock(nn.Module): ...@@ -245,6 +213,175 @@ class EvoformerBlock(nn.Module):
return m, z return m, z
class EvoformerBlock(nn.Module):
    """
    One Evoformer iteration: row- and column-wise MSA attention followed
    by the shared EvoformerBlockCore (outer-product mean, triangle
    updates, transitions). Kept free of fine-grained checkpointing so it
    remains TorchScript-able (cf. ExtraMSABlock).
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(EvoformerBlock, self).__init__()

        # Row-wise MSA attention, biased by the pair representation
        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        # Column-wise (per-residue, across-sequence) MSA attention
        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        # Dropout is applied row-wise, and only to the row-attention update
        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        # Everything downstream of MSA attention is shared with ExtraMSABlock
        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            chunk_size:
                Optional chunk size forwarded to the attention/transition
                sub-modules to trade speed for memory
            _mask_trans:
                Whether transitions are masked (disable to better match
                the original DeepMind activations)
        Returns:
            Updated (m, z) pair
        """
        row_update = self.msa_att_row(
            m, z=z, mask=msa_mask, chunk_size=chunk_size
        )
        m = m + self.msa_dropout_layer(row_update)
        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)

        m, z = self.core(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        return m, z
class ExtraMSABlock(nn.Module):
    """
    Almost identical to the standard EvoformerBlock, except in that the
    ExtraMSABlock uses GlobalAttention for MSA column attention and
    requires more fine-grained control over checkpointing. Separated from
    its twin to preserve the TorchScript-ability of the latter.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
    ):
        super(ExtraMSABlock, self).__init__()

        # When True (and grad is enabled), the chunked row attention and
        # the column-attention/core step are run under activation
        # checkpointing (see forward()).
        self.ckpt = ckpt

        # Row-wise MSA attention, biased by the pair representation
        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        # Global (mean-query) column attention: cheaper than full column
        # attention, suited to the large extra-MSA stack
        self.msa_att_col = MSAColumnGlobalAttention(
            c_in=c_m,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
            eps=eps,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        # Post-attention updates shared with the standard EvoformerBlock
        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = 1024,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] extra-MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            chunk_size:
                Chunk size forwarded to sub-modules
            _chunk_logits:
                Chunk size for the attention-logit chunking inside the row
                attention (None disables it)
        Returns:
            Updated (m, z) pair
        """
        # NOTE(review): the clone() calls presumably decouple these inputs
        # from in-place/aliasing effects during the checkpointed, chunked
        # row attention's recomputation — confirm before removing.
        # Chunk checkpointing is only requested when grads are enabled.
        m = m + self.msa_dropout_layer(
            self.msa_att_row(
                m.clone(),
                z=z.clone(),
                mask=msa_mask,
                chunk_size=chunk_size,
                _chunk_logits=_chunk_logits,
                _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
            )
        )

        # Column attention + core, optionally wrapped as a single
        # checkpointed unit.
        def fn(m, z):
            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)

            m, z = self.core(
                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
            )

            return m, z

        if(torch.is_grad_enabled() and self.ckpt):
            # get_checkpoint_fn selects the checkpoint implementation at
            # call time (e.g. torch vs. DeepSpeed variants)
            checkpoint_fn = get_checkpoint_fn()
            m, z = checkpoint_fn(fn, m, z)
        else:
            m, z = fn(m, z)

        return m, z
class EvoformerStack(nn.Module): class EvoformerStack(nn.Module):
""" """
Main Evoformer trunk. Main Evoformer trunk.
...@@ -271,7 +408,6 @@ class EvoformerStack(nn.Module): ...@@ -271,7 +408,6 @@ class EvoformerStack(nn.Module):
inf: float, inf: float,
eps: float, eps: float,
clear_cache_between_blocks: bool = False, clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False,
**kwargs, **kwargs,
): ):
""" """
...@@ -313,7 +449,6 @@ class EvoformerStack(nn.Module): ...@@ -313,7 +449,6 @@ class EvoformerStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks self.clear_cache_between_blocks = clear_cache_between_blocks
self._is_extra_msa_stack = _is_extra_msa_stack
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
...@@ -332,15 +467,12 @@ class EvoformerStack(nn.Module): ...@@ -332,15 +467,12 @@ class EvoformerStack(nn.Module):
pair_dropout=pair_dropout, pair_dropout=pair_dropout,
inf=inf, inf=inf,
eps=eps, eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
) )
self.blocks.append(block) self.blocks.append(block)
if not self._is_extra_msa_stack: self.linear = Linear(c_m, c_s)
self.linear = Linear(c_m, c_s)
def forward( def forward(self,
self,
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
...@@ -390,13 +522,8 @@ class EvoformerStack(nn.Module): ...@@ -390,13 +522,8 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
s = None s = self.linear(m[..., 0, :, :])
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)
return m, z, s return m, z, s
...@@ -405,8 +532,7 @@ class ExtraMSAStack(nn.Module): ...@@ -405,8 +532,7 @@ class ExtraMSAStack(nn.Module):
Implements Algorithm 18. Implements Algorithm 18.
""" """
def __init__( def __init__(self,
self,
c_m: int, c_m: int,
c_z: int, c_z: int,
c_hidden_msa_att: int, c_hidden_msa_att: int,
...@@ -419,38 +545,38 @@ class ExtraMSAStack(nn.Module): ...@@ -419,38 +545,38 @@ class ExtraMSAStack(nn.Module):
transition_n: int, transition_n: int,
msa_dropout: float, msa_dropout: float,
pair_dropout: float, pair_dropout: float,
blocks_per_ckpt: int,
inf: float, inf: float,
eps: float, eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False, clear_cache_between_blocks: bool = False,
**kwargs, **kwargs,
): ):
super(ExtraMSAStack, self).__init__() super(ExtraMSAStack, self).__init__()
c_s = None self.clear_cache_between_blocks = clear_cache_between_blocks
self.stack = EvoformerStack(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
c_s=c_s,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
no_blocks=no_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
blocks_per_ckpt=blocks_per_ckpt,
inf=inf,
eps=eps,
clear_cache_between_blocks=clear_cache_between_blocks,
_is_extra_msa_stack=True,
)
def forward( self.blocks = nn.ModuleList()
self,
for _ in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=False,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
...@@ -470,13 +596,28 @@ class ExtraMSAStack(nn.Module): ...@@ -470,13 +596,28 @@ class ExtraMSAStack(nn.Module):
Optional [*, N_res, N_res] pair mask Optional [*, N_res, N_res] pair mask
Returns: Returns:
[*, N_res, N_res, C_z] pair update [*, N_res, N_res, C_z] pair update
""" """
_, z, _ = self.stack( #checkpoint_fn = get_checkpoint_fn()
m, #blocks = [
z, # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
msa_mask=msa_mask, #]
pair_mask=pair_mask,
chunk_size=chunk_size, #def dodo(b, *args):
_mask_trans=_mask_trans, # torch.cuda.empty_cache()
) # return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z return z
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
compute_tm, compute_tm,
...@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module): ...@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module):
self.c_in = c_in self.c_in = c_in
self.c_hidden = c_hidden self.c_hidden = c_hidden
self.layer_norm = nn.LayerNorm(self.c_in) self.layer_norm = LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu") self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu") self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
......
...@@ -134,7 +134,7 @@ class AlphaFold(nn.Module): ...@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
inf=self.config.template.inf, inf=self.config.template.inf,
eps=self.config.template.eps, eps=self.config.template.eps,
**self.config.template.distogram, **self.config.template.distogram,
) ).to(z.dtype)
t = self.template_pair_embedder(t) t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t}) single_template_embeds.update({"pair": t})
...@@ -149,7 +149,7 @@ class AlphaFold(nn.Module): ...@@ -149,7 +149,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = self.template_pair_stack( t = self.template_pair_stack(
template_embeds["pair"], template_embeds["pair"],
pair_mask.unsqueeze(-3), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -158,7 +158,7 @@ class AlphaFold(nn.Module): ...@@ -158,7 +158,7 @@ class AlphaFold(nn.Module):
t = self.template_pointwise_att( t = self.template_pointwise_att(
t, t,
z, z,
template_mask=batch["template_mask"], template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
) )
t = t * (torch.sum(batch["template_mask"]) > 0) t = t * (torch.sum(batch["template_mask"]) > 0)
...@@ -175,6 +175,12 @@ class AlphaFold(nn.Module): ...@@ -175,6 +175,12 @@ class AlphaFold(nn.Module):
# Primary output dictionary # Primary output dictionary
outputs = {} outputs = {}
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input # Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2] batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims) no_batch_dims = len(batch_dims)
...@@ -217,7 +223,9 @@ class AlphaFold(nn.Module): ...@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
requires_grad=False, requires_grad=False,
) )
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None) x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# m_1_prev_emb: [*, N, C_m] # m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z] # z_prev_emb: [*, N, N, C_z]
...@@ -246,34 +254,32 @@ class AlphaFold(nn.Module): ...@@ -246,34 +254,32 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled: if self.config.template.enabled:
template_mask = feats["template_mask"] template_feats = {
if(torch.any(template_mask)): k: v for k, v in feats.items() if k.startswith("template_")
template_feats = { }
k: v for k, v in feats.items() if k.startswith("template_") template_embeds = self.embed_templates(
} template_feats,
template_embeds = self.embed_templates( z,
template_feats, pair_mask.to(dtype=z.dtype),
z, no_batch_dims,
pair_mask, )
no_batch_dims,
)
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles: if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
[m, template_embeds["template_angle_embedding"]], [m, template_embeds["template_angle_embedding"]],
dim=-3 dim=-3
) )
# [*, S, N] # [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"] torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat( msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], [feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2 dim=-2
) )
# Embed extra MSA features + merge with pairwise embeddings # Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled: if self.config.extra_msa.enabled:
...@@ -284,9 +290,9 @@ class AlphaFold(nn.Module): ...@@ -284,9 +290,9 @@ class AlphaFold(nn.Module):
z = self.extra_msa_stack( z = self.extra_msa_stack(
a, a,
z, z,
msa_mask=feats["extra_msa_mask"], msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
pair_mask=pair_mask, pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -297,8 +303,8 @@ class AlphaFold(nn.Module): ...@@ -297,8 +303,8 @@ class AlphaFold(nn.Module):
m, z, s = self.evoformer( m, z, s = self.evoformer(
m, m,
z, z,
msa_mask=msa_mask, msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask, pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -312,7 +318,7 @@ class AlphaFold(nn.Module): ...@@ -312,7 +318,7 @@ class AlphaFold(nn.Module):
s, s,
z, z,
feats["aatype"], feats["aatype"],
mask=feats["seq_mask"], mask=feats["seq_mask"].to(dtype=s.dtype),
) )
outputs["final_atom_positions"] = atom14_to_atom37( outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats outputs["sm"]["positions"][-1], feats
...@@ -336,7 +342,9 @@ class AlphaFold(nn.Module): ...@@ -336,7 +342,9 @@ class AlphaFold(nn.Module):
def _disable_activation_checkpointing(self): def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None self.evoformer.blocks_per_ckpt = None
self.extra_msa_stack.stack.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self): def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = ( self.template_pair_stack.blocks_per_ckpt = (
...@@ -345,9 +353,9 @@ class AlphaFold(nn.Module): ...@@ -345,9 +353,9 @@ class AlphaFold(nn.Module):
self.evoformer.blocks_per_ckpt = ( self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt self.config.evoformer_stack.blocks_per_ckpt
) )
self.extra_msa_stack.stack.blocks_per_ckpt = (
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt for b in self.extra_msa_stack.blocks:
) b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch): def forward(self, batch):
""" """
......
...@@ -16,9 +16,16 @@ ...@@ -16,9 +16,16 @@
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, List from typing import Optional, List, Tuple
from openfold.model.primitives import Linear, Attention, GlobalAttention from openfold.model.primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
...@@ -61,16 +68,16 @@ class MSAAttention(nn.Module): ...@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
self.c_z = c_z self.c_z = c_z
self.inf = inf self.inf = inf
self.layer_norm_m = nn.LayerNorm(self.c_in) self.layer_norm_m = LayerNorm(self.c_in)
self.layer_norm_z = None self.layer_norm_z = None
self.linear_z = None self.linear_z = None
if self.pair_bias: if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = LayerNorm(self.c_z)
self.linear_z = Linear( self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal" self.c_z, self.no_heads, bias=False, init="normal"
) )
self.mha = Attention( self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
) )
...@@ -83,32 +90,16 @@ class MSAAttention(nn.Module): ...@@ -83,32 +90,16 @@ class MSAAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
return chunk_layer( return chunk_layer(
self.mha, self.mha,
{"q_x": m, "k_x": m, "v_x": m, "biases": biases}, {"q_x": m, "kv_x": m, "biases": biases},
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]), no_batch_dims=len(m.shape[:-2]),
) )
def forward(self, def _prep_inputs(self,
m: torch.Tensor, m: torch.Tensor,
z: Optional[torch.Tensor] = None, z: Optional[torch.Tensor],
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor]
chunk_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
# [*, N_seq, N_res, C_m] # [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m) m = self.layer_norm_m(m)
...@@ -120,16 +111,14 @@ class MSAAttention(nn.Module): ...@@ -120,16 +111,14 @@ class MSAAttention(nn.Module):
) )
# [*, N_seq, 1, 1, N_res] # [*, N_seq, 1, 1, N_res]
bias = (self.inf * (mask - 1))[..., :, None, None, :] mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# This step simply returns a larger view of the bias, and does not # This step simply returns a larger view of the bias, and does not
# consume additional memory. # consume additional memory.
# [*, N_seq, no_heads, N_res, N_res] # [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand( #bias = bias.expand(
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
) #)
biases = [bias]
if (self.pair_bias and if (self.pair_bias and
z is not None and # For the z is not None and # For the
...@@ -138,19 +127,98 @@ class MSAAttention(nn.Module): ...@@ -138,19 +127,98 @@ class MSAAttention(nn.Module):
): ):
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm_z(z) z = self.layer_norm_z(z)
# [*, N_res, N_res, no_heads] # [*, N_res, N_res, no_heads]
z = self.linear_z(z) z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res] # [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
return m, mask_bias, z
@torch.jit.ignore
def _chunked_msa_attn(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
) -> torch.Tensor:
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(torch.is_grad_enabled() and checkpoint):
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
m, q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
key=k,
value=v,
biases=[mask_bias, z],
chunk_size=chunk_logits,
chunk_dim=MSA_DIM,
checkpoint=checkpoint,
)
if(torch.is_grad_enabled() and checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
m = self.mha._wrap_up(o, m)
return m
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z) biases.append(z)
if chunk_size is not None: if chunk_size is not None:
m = self._chunk(m, biases, chunk_size) m = self._chunk(m, biases, chunk_size)
else: else:
m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases) m = self.mha(
q_x=m,
kv_x=m,
biases=biases
)
return m return m
......
...@@ -17,7 +17,7 @@ from typing import Optional ...@@ -17,7 +17,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.tensor_utils import chunk_layer
...@@ -40,7 +40,7 @@ class PairTransition(nn.Module): ...@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
self.c_z = c_z self.c_z = c_z
self.n = n self.n = n
self.layer_norm = nn.LayerNorm(self.c_z) self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
......
...@@ -13,14 +13,17 @@ ...@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import math import math
from typing import Optional, Callable, List, Tuple, Sequence from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np import numpy as np
import deepspeed
import torch import torch
import torch.nn as nn import torch.nn as nn
from scipy.stats import truncnorm from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
...@@ -164,6 +167,135 @@ class Linear(nn.Linear): ...@@ -164,6 +167,135 @@ class Linear(nn.Linear):
raise ValueError("Invalid init string.") raise ValueError("Invalid init string.")
class LayerNorm(nn.Module):
    """
    LayerNorm over the last dimension that keeps bfloat16 inputs in
    bfloat16: outside of DeepSpeed, autocast is disabled for the
    normalization and the affine parameters are cast to the input dtype,
    preventing an automatic upcast to fp32.
    """

    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        # Stored as a 1-tuple: the normalized_shape argument of layer_norm
        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        dtype = x.dtype
        bf16_outside_deepspeed = (
            dtype is torch.bfloat16 and not deepspeed.utils.is_initialized()
        )
        if bf16_outside_deepspeed:
            # Disable autocast so the computation stays in bf16; the
            # parameters are cast to match the input
            with torch.cuda.amp.autocast(enabled=False):
                return nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=dtype),
                    self.bias.to(dtype=dtype),
                    self.eps,
                )
        return nn.functional.layer_norm(
            x,
            self.c_in,
            self.weight,
            self.bias,
            self.eps,
        )
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    if t.dtype is torch.bfloat16 and not deepspeed.utils.is_initialized():
        # Autocast would silently promote a bf16 softmax to fp32;
        # disable it so the computation stays in bf16
        with torch.cuda.amp.autocast(enabled=False):
            return torch.nn.functional.softmax(t, dim=dim)
    return torch.nn.functional.softmax(t, dim=dim)
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    """
    Plain (unchunked) scaled-dot-product attention over pre-projected
    q/k/v tensors of shape [*, Q/K/V, H, C_hidden]. Assumes the query was
    already scaled by the caller. Returns [*, Q, H, C_hidden].
    """
    # Bring the head dim forward: [*, H, Q, C_hidden]
    q = permute_final_dims(query, (1, 0, 2))
    # [*, H, C_hidden, K], so q @ k yields the logits directly
    k = permute_final_dims(key, (1, 2, 0))
    # [*, H, V, C_hidden]
    v = permute_final_dims(value, (1, 0, 2))

    # [*, H, Q, K] attention logits, plus any broadcastable bias terms
    logits = torch.matmul(q, k)
    for bias in biases:
        logits = logits + bias

    # bf16-safe softmax (defined above)
    weights = softmax(logits, -1)

    # [*, H, Q, C_hidden]
    out = torch.matmul(weights, v)

    # Back to [*, Q, H, C_hidden]
    return out.transpose(-2, -3)
@torch.jit.ignore
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if(checkpoint and len(biases) > 2):
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b for b in [b1, b2] if b is not None]
return _attention(q, k, v, bs)
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = query.shape[chunk_dim]
for start in range(0, count, chunk_size):
end = start + chunk_size
idx = [slice(None)] * len(query.shape)
idx[chunk_dim] = slice(start, end)
idx_tup = tuple(idx)
q_chunk = query[idx_tup]
k_chunk = key[idx_tup]
v_chunk = value[idx_tup]
def _slice_bias(b):
idx[chunk_dim] = (
slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
)
return b[tuple(idx)]
if(checkpoint):
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
]
o_chunk = checkpoint_fn(_checkpointable_attention,
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
)
else:
bias_chunks = [
_slice_bias(b) for b in biases
]
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
return o
class Attention(nn.Module): class Attention(nn.Module):
""" """
Standard multi-head attention using AlphaFold's default layer Standard multi-head attention using AlphaFold's default layer
...@@ -225,66 +357,34 @@ class Attention(nn.Module): ...@@ -225,66 +357,34 @@ class Attention(nn.Module):
) )
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward( def _prep_qkv(self,
self, q_x: torch.Tensor,
q_x: torch.Tensor, kv_x: torch.Tensor
k_x: torch.Tensor, ) -> Tuple[
v_x: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
biases: Optional[List[torch.Tensor]] = None, ]:
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
k_x:
[*, K, C_k] key data
v_x:
[*, V, C_v] value data
Returns
[*, Q, C_q] attention update
"""
# [*, Q/K/V, H * C_hidden] # [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x) q = self.linear_q(q_x)
k = self.linear_k(k_x) k = self.linear_k(kv_x)
v = self.linear_v(v_x) v = self.linear_v(kv_x)
# [*, Q/K, H, C_hidden] # [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1)) q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1)) k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1)) v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q, C_hidden] q /= math.sqrt(self.c_hidden)
q = permute_final_dims(q, (1, 0, 2))
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, Q, K]
a = torch.matmul(q, k)
del q, k
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
if biases is not None:
for b in biases:
a += b
a = self.softmax(a) return q, k, v
# [*, H, V, C_hidden] def _wrap_up(self,
v = permute_final_dims(v, (1, 0, 2)) o: torch.Tensor,
q_x: torch.Tensor
# [*, H, Q, C_hidden] ) -> torch.Tensor:
o = torch.matmul(a, v)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
if(self.linear_g is not None): if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x)) g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1)) g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g o = o * g
...@@ -297,6 +397,56 @@ class Attention(nn.Module): ...@@ -297,6 +397,56 @@ class Attention(nn.Module):
return o return o
def forward(
self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
kv_x:
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_lma:
Whether to use low-memory attention
q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
)
q, k, v = self._prep_qkv(q_x, kv_x)
if(use_lma):
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
o = _attention(q, k, v, biases)
o = self._wrap_up(o, q_x)
return o
class GlobalAttention(nn.Module): class GlobalAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, inf, eps): def __init__(self, c_in, c_hidden, no_heads, inf, eps):
...@@ -322,7 +472,6 @@ class GlobalAttention(nn.Module): ...@@ -322,7 +472,6 @@ class GlobalAttention(nn.Module):
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in] # [*, N_res, C_in]
...@@ -348,7 +497,7 @@ class GlobalAttention(nn.Module): ...@@ -348,7 +497,7 @@ class GlobalAttention(nn.Module):
) )
bias = (self.inf * (mask - 1))[..., :, None, :] bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias a += bias
a = self.softmax(a) a = softmax(a)
# [*, N_res, H, C_hidden] # [*, N_res, H, C_hidden]
o = torch.matmul( o = torch.matmul(
...@@ -374,14 +523,13 @@ class GlobalAttention(nn.Module): ...@@ -374,14 +523,13 @@ class GlobalAttention(nn.Module):
return m return m
@torch.jit.script
def _lma( def _lma(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
biases: List[torch.Tensor], biases: List[torch.Tensor],
q_chunk_size: int, q_chunk_size: int,
kv_chunk_size: int kv_chunk_size: int,
): ):
no_q, no_kv = q.shape[-3], k.shape[-3] no_q, no_kv = q.shape[-3], k.shape[-3]
...@@ -389,34 +537,34 @@ def _lma( ...@@ -389,34 +537,34 @@ def _lma(
o = q.new_zeros(q.shape) o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size): for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :] q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
big_bias_chunks = [ large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases b[..., q_s: q_s + q_chunk_size, :] for b in biases
] ]
maxes = [] maxes = []
weights = [] weights = []
values = [] values = []
for kv_s in range(0, no_kv, kv_chunk_size): for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :] k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :] v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
small_bias_chunks = [ small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
] ]
a = torch.einsum( a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk "...qhd,...khd->...hqk", q_chunk, k_chunk,
) )
for b in small_bias_chunks: for b in small_bias_chunks:
a += b a += b
a = a.transpose(-2, -3) a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0].detach() max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a) exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.squeeze(-1)) maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1)) weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v) values.append(exp_v)
...@@ -437,111 +585,3 @@ def _lma( ...@@ -437,111 +585,3 @@ def _lma(
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
return o return o
class LowMemoryAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors. Implements Rabe and Staats'
low-memory self-attention algorithm.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
chunk_size:
Trades memory for better parallelization. A low value
corresponds to lower memory usage.
"""
super().__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_chunk_size: int,
kv_chunk_size: int,
biases: Optional[List[torch.Tensor]] = None,
):
if(biases is None):
biases = []
else:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(q.shape[-1])
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
...@@ -18,7 +18,7 @@ import torch ...@@ -18,7 +18,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, Tuple from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_ from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
...@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module): ...@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
) )
a = a * math.sqrt(1.0 / (3 * self.c_hidden)) a *= math.sqrt(1.0 / (3 * self.c_hidden))
a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3] # [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...@@ -323,7 +323,7 @@ class InvariantPointAttention(nn.Module): ...@@ -323,7 +323,7 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res] # [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1)) pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att a = a + pt_att
a = a + square_mask.unsqueeze(-3) a = a + square_mask.unsqueeze(-3)
a = self.softmax(a) a = self.softmax(a)
...@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module): ...@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
# Compute output # Compute output
################ ################
# [*, N_res, H, C_hidden] # [*, N_res, H, C_hidden]
o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3) o = torch.matmul(
a, v.transpose(-2, -3).to(dtype=a.dtype)
).transpose(-2, -3)
# [*, N_res, H * C_hidden] # [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2) o = flatten_final_dims(o, 2)
...@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module): ...@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, C_z] # [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z) o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
# [*, N_res, H * C_z] # [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2) o_pair = flatten_final_dims(o_pair, 2)
...@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module): ...@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
s = self.linear_out( s = self.linear_out(
torch.cat( torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
) ).to(dtype=z.dtype)
) )
return s return s
...@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module): ...@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
self.layers.append(l) self.layers.append(l)
self.dropout = nn.Dropout(self.dropout_rate) self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = nn.LayerNorm(self.c) self.layer_norm = LayerNorm(self.c)
def forward(self, s): def forward(self, s):
for l in self.layers: for l in self.layers:
...@@ -534,8 +536,8 @@ class StructureModule(nn.Module): ...@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
self.atom_mask = None self.atom_mask = None
self.lit_positions = None self.lit_positions = None
self.layer_norm_s = nn.LayerNorm(self.c_s) self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = LayerNorm(self.c_z)
self.linear_in = Linear(self.c_s, self.c_s) self.linear_in = Linear(self.c_s, self.c_s)
...@@ -551,7 +553,7 @@ class StructureModule(nn.Module): ...@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
) )
self.ipa_dropout = nn.Dropout(self.dropout_rate) self.ipa_dropout = nn.Dropout(self.dropout_rate)
self.layer_norm_ipa = nn.LayerNorm(self.c_s) self.layer_norm_ipa = LayerNorm(self.c_s)
self.transition = StructureModuleTransition( self.transition = StructureModuleTransition(
self.c_s, self.c_s,
......
...@@ -19,7 +19,7 @@ from typing import Optional, List ...@@ -19,7 +19,7 @@ from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, Attention from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.dropout import ( from openfold.model.dropout import (
DropoutRowwise, DropoutRowwise,
DropoutColumnwise, DropoutColumnwise,
...@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
mha_inputs = { mha_inputs = {
"q_x": z, "q_x": z,
"k_x": t, "kv_x": t,
"v_x": t,
"biases": biases, "biases": biases,
} }
return chunk_layer( return chunk_layer(
...@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
if chunk_size is not None: if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size) z = self._chunk(z, t, biases, chunk_size)
else: else:
z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases) z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = z.squeeze(-2) z = z.squeeze(-2)
...@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module): ...@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
) )
self.blocks.append(block) self.blocks.append(block)
self.layer_norm = nn.LayerNorm(c_t) self.layer_norm = LayerNorm(c_t)
def forward( def forward(
self, self,
......
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partialmethod from functools import partialmethod, partial
import math import math
from typing import Optional, List from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, Attention from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
...@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module): ...@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
self.starting = starting self.starting = starting
self.inf = inf self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in) self.layer_norm = LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
...@@ -65,12 +65,11 @@ class TriangleAttention(nn.Module): ...@@ -65,12 +65,11 @@ class TriangleAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
mha_inputs = { mha_inputs = {
"q_x": x, "q_x": x,
"k_x": x, "kv_x": x,
"v_x": x,
"biases": biases, "biases": biases,
} }
return chunk_layer( return chunk_layer(
self.mha, partial(self.mha),
mha_inputs, mha_inputs,
chunk_size=chunk_size, chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]), no_batch_dims=len(x.shape[:-2]),
...@@ -116,7 +115,7 @@ class TriangleAttention(nn.Module): ...@@ -116,7 +115,7 @@ class TriangleAttention(nn.Module):
if chunk_size is not None: if chunk_size is not None:
x = self._chunk(x, biases, chunk_size) x = self._chunk(x, biases, chunk_size)
else: else:
x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases) x = self.mha(q_x=x, kv_x=x, biases=biases)
if not self.starting: if not self.starting:
x = x.transpose(-2, -3) x = x.transpose(-2, -3)
......
...@@ -19,7 +19,7 @@ from typing import Optional ...@@ -19,7 +19,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims from openfold.utils.tensor_utils import permute_final_dims
...@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.linear_g = Linear(self.c_z, self.c_z, init="gating") self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final") self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = nn.LayerNorm(self.c_z) self.layer_norm_in = LayerNorm(self.c_z)
self.layer_norm_out = nn.LayerNorm(self.c_hidden) self.layer_norm_out = LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
......
...@@ -15,17 +15,27 @@ ...@@ -15,17 +15,27 @@
import deepspeed import deepspeed
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable from typing import Any, Tuple, List, Callable, Optional
BLOCK_ARG = Any BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG] BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
return checkpoint
@torch.jit.ignore @torch.jit.ignore
def checkpoint_blocks( def checkpoint_blocks(
blocks: List[Callable], blocks: List[Callable],
args: BLOCK_ARGS, args: BLOCK_ARGS,
blocks_per_ckpt: int, blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS: ) -> BLOCK_ARGS:
""" """
Chunk a list of blocks and run each chunk with activation Chunk a list of blocks and run each chunk with activation
...@@ -68,10 +78,7 @@ def checkpoint_blocks( ...@@ -68,10 +78,7 @@ def checkpoint_blocks(
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
if(deepspeed.checkpointing.is_configured()): checkpoint = get_checkpoint_fn()
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
for s in range(0, len(blocks), blocks_per_ckpt): for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt e = s + blocks_per_ckpt
......
...@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
b.msa_att_row b.msa_att_row
), ),
col_att_name: msa_col_att_params, col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition), "msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean), "outer_product_mean":
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out), OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in), "triangle_multiplication_outgoing":
"triangle_attention_starting_node": TriAttParams(b.tri_att_start), TriMulOutParams(b.core.tri_mul_out),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end), "triangle_multiplication_incoming":
"pair_transition": PairTransitionParams(b.pair_transition), TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
} }
return d return d
...@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
[TemplatePairBlockParams(b) for b in tps_blocks] [TemplatePairBlockParams(b) for b in tps_blocks]
) )
ems_blocks = model.extra_msa_stack.stack.blocks ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks]) ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks evo_blocks = model.evoformer.blocks
......
...@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels): ...@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels): def sigmoid_cross_entropy(logits, labels):
log_p = torch.nn.functional.logsigmoid(logits) log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.nn.functional.logsigmoid(-logits) log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p loss = -labels * log_p - (1 - labels) * log_not_p
return loss return loss
...@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor: ...@@ -329,26 +329,15 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
return pred_lddt_ca * 100 return pred_lddt_ca * 100
def lddt_loss( def lddt(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor, all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor, all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor, all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0, cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10, eps: float = 1e-10,
**kwargs, per_residue: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
n = all_atom_mask.shape[-2] n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
dmat_true = torch.sqrt( dmat_true = torch.sqrt(
eps eps
+ torch.sum( + torch.sum(
...@@ -389,8 +378,63 @@ def lddt_loss( ...@@ -389,8 +378,63 @@ def lddt_loss(
) )
score = score * 0.25 score = score * 0.25
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1)) dims = (-1,) if per_residue else (-2, -1)
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1)) norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
def lddt_ca(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps,
per_residue=per_residue,
)
def lddt_loss(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps
)
score = score.detach() score = score.detach()
...@@ -1462,7 +1506,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1462,7 +1506,7 @@ class AlphaFoldLoss(nn.Module):
self.config = config self.config = config
def forward(self, out, batch): def forward(self, out, batch):
if "violation" not in out.keys() and self.config.violation.weight: if "violation" not in out.keys():
out["violation"] = find_structural_violations( out["violation"] = find_structural_violations(
batch, batch,
out["sm"]["positions"][-1], out["sm"]["positions"][-1],
...@@ -1509,22 +1553,26 @@ class AlphaFoldLoss(nn.Module): ...@@ -1509,22 +1553,26 @@ class AlphaFoldLoss(nn.Module):
out["violation"], out["violation"],
**batch, **batch,
), ),
"tm": lambda: tm_loss( }
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"], logits=out["tm_logits"],
**{**batch, **out, **self.config.tm}, **{**batch, **out, **self.config.tm},
), )
}
cum_loss = 0. cum_loss = 0.
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight: loss = loss_fn()
loss = loss_fn() if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
if(torch.isnan(loss) or torch.isinf(loss)): loss = loss.new_tensor(0., requires_grad=True)
logging.warning(f"{loss_name} loss is NaN. Skipping...") cum_loss = cum_loss + weight * loss
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]
cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
# Scale the loss by the square root of the minimum of the crop size and # Scale the loss by the square root of the minimum of the crop size and
# the (average) sequence length. See subsection 1.9. # the (average) sequence length. See subsection 1.9.
......
...@@ -26,7 +26,7 @@ def rot_matmul( ...@@ -26,7 +26,7 @@ def rot_matmul(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Performs matrix multiplication of two rotation matrix tensors. Written Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid transfer to low-precision tensor cores. out by hand to avoid AMP downcasting.
Args: Args:
a: [*, 3, 3] left multiplicand a: [*, 3, 3] left multiplicand
...@@ -86,7 +86,7 @@ def rot_vec_mul( ...@@ -86,7 +86,7 @@ def rot_vec_mul(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Applies a rotation to a vector. Written out by hand to avoid transfer Applies a rotation to a vector. Written out by hand to avoid transfer
to low-precision tensor cores. to avoid AMP downcasting.
Args: Args:
r: [*, 3, 3] rotation matrices r: [*, 3, 3] rotation matrices
...@@ -323,6 +323,12 @@ class Rotation: ...@@ -323,6 +323,12 @@ class Rotation:
"Incorrectly shaped rotation matrix or quaternion" "Incorrectly shaped rotation matrix or quaternion"
) )
# Force full-precision
if(quats is not None):
quats = quats.to(dtype=torch.float32)
if(rot_mats is not None):
rot_mats = rot_mats.to(dtype=torch.float32)
if(quats is not None and normalize_quats): if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
...@@ -857,6 +863,9 @@ class Rigid: ...@@ -857,6 +863,9 @@ class Rigid:
(rots.device != trans.device)): (rots.device != trans.device)):
raise ValueError("Rots and trans incompatible") raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots self._rots = rots
self._trans = trans self._trans = trans
......
import argparse
from functools import partial
import logging
from multiprocessing import Pool
import os
import sys
import json
sys.path.append(".") # an innocent hack to get this to run from the top level
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args):
    """
    Parse one mmCIF file and return {file_id: {release_date, no_chains}},
    or {} if the file could not be parsed.
    """
    path = os.path.join(args.mmcif_dir, f)
    with open(path, "r") as fp:
        mmcif_string = fp.read()

    file_id = os.path.splitext(f)[0]
    parsed = parse(file_id=file_id, mmcif_string=mmcif_string)
    if parsed.mmcif_object is None:
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    obj = parsed.mmcif_object
    return {
        file_id: {
            "release_date": obj.header["release_date"],
            "no_chains": len(list(obj.structure.get_chains())),
        }
    }
def main(args):
    """
    Parse every .cif file in args.mmcif_dir in a worker pool and write the
    merged per-file metadata to args.output_path as JSON.
    """
    cif_files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
    parse_fn = partial(parse_file, args=args)

    data = {}
    with Pool(processes=args.no_workers) as pool:
        with tqdm(total=len(cif_files)) as progress:
            for entry in pool.imap_unordered(
                parse_fn, cif_files, chunksize=args.chunksize
            ):
                data.update(entry)
                progress.update()

    with open(args.output_path, "w") as fp:
        json.dump(data, fp, indent=4)
if __name__ == "__main__":
    # CLI: positional input/output paths plus worker-pool tuning knobs.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "mmcif_dir", type=str, help="Directory containing mmCIF files"
    )
    arg_parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    arg_parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    arg_parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )
    main(arg_parser.parse_args())
import argparse import argparse
from functools import partial
import json
import logging import logging
import os import os
import threading
from multiprocessing import cpu_count
from shutil import copyfile
import tempfile import tempfile
import openfold.data.mmcif_parsing as mmcif_parsing import openfold.data.mmcif_parsing as mmcif_parsing
...@@ -10,30 +15,58 @@ from openfold.np import protein, residue_constants ...@@ -10,30 +15,58 @@ from openfold.np import protein, residue_constants
from utils import add_data_args from utils import add_data_args
#python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ data/uniref90/uniref90.fasta data/mgnify/mgy_clusters_2018_12.fa data/pdb70/pdb70 data/pdb_mmcif/mmcif_files/ data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt --cpus 16 --jackhmmer_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/jackhmmer --hhblits_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhblits --hhsearch_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhsearch --kalign_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/kalign
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.WARNING)
def run_seq_group_alignments(seq_groups, alignment_runner, args):
    """Compute alignments once per unique sequence and copy them to duplicates.

    Args:
        seq_groups: iterable of ``(sequence, [chain_names])`` pairs. Every
            chain in a group shares the same sequence, so alignments are run
            only for the first chain and copied to the others.
        alignment_runner: object exposing ``run(fasta_path, output_dir)``
            (an AlignmentRunner) that writes alignment files into output_dir.
        args: parsed CLI namespace; only ``args.output_dir`` is read here.
    """
    for seq, names in seq_groups:
        first_name = names[0]
        alignment_dir = os.path.join(args.output_dir, first_name)
        os.makedirs(alignment_dir, exist_ok=True)

        # Write the query sequence to a temporary FASTA file for the runner.
        fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
        with os.fdopen(fd, 'w') as fp:
            fp.write(f'>query\n{seq}')

        try:
            alignment_runner.run(
                fasta_path, alignment_dir
            )
        except Exception as e:
            # Best-effort batch processing: log (with the actual error, which
            # the old bare "except:" discarded) and move on to the next group.
            logging.warning(
                f"Failed to run alignments for {first_name}. Skipping..."
            )
            logging.warning(e)
            os.remove(fasta_path)
            try:
                os.rmdir(alignment_dir)
            except OSError:
                # The runner may have left partial output behind; a non-empty
                # directory would make rmdir raise and kill the worker thread.
                pass
            continue

        os.remove(fasta_path)

        # Fan the computed alignments out to every other chain in the group.
        for name in names[1:]:
            cp_dir = os.path.join(args.output_dir, name)
            os.makedirs(cp_dir, exist_ok=True)
            for f in os.listdir(alignment_dir):
                copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
def parse_and_align(files, alignment_runner, args):
for f in files:
path = os.path.join(args.input_dir, f) path = os.path.join(args.input_dir, f)
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
seqs = {} seq_group_dict = {}
if(f.endswith('.cif')): if(f.endswith('.cif')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
mmcif_str = fp.read() mmcif_str = fp.read()
...@@ -47,9 +80,10 @@ def main(args): ...@@ -47,9 +80,10 @@ def main(args):
else: else:
continue continue
mmcif = mmcif.mmcif_object mmcif = mmcif.mmcif_object
for k,v in mmcif.chain_to_seqres.items(): for chain_letter, seq in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k]) chain_id = '_'.join([file_id, chain_letter])
seqs[chain_id] = v l = seq_group_dict.setdefault(seq, [])
l.append(chain_id)
elif(f.endswith('.fasta') or f.endswith('.fa')): elif(f.endswith('.fasta') or f.endswith('.fa')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
fasta_str = fp.read() fasta_str = fp.read()
...@@ -61,7 +95,7 @@ def main(args): ...@@ -61,7 +95,7 @@ def main(args):
else: else:
logging.warning(msg) logging.warning(msg)
input_sequence = input_seqs[0] input_sequence = input_seqs[0]
seqs[file_id] = input_sequence seq_group_dict[input_sequence] = [file_id]
elif(f.endswith('.core')): elif(f.endswith('.core')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
core_str = fp.read() core_str = fp.read()
...@@ -71,27 +105,114 @@ def main(args): ...@@ -71,27 +105,114 @@ def main(args):
residue_constants.restypes_with_x[aatype[i]] residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype)) for i in range(len(aatype))
]) ])
seqs[file_id] = seq seq_group_dict[seq] = [file_id]
else: else:
continue continue
for name, seq in seqs.items(): seq_group_tuples = [(k,v) for k,v in seq_group_dict.items()]
alignment_dir = os.path.join(args.output_dir, name) run_seq_group_alignments(seq_group_tuples, alignment_runner, args)
if(os.path.isdir(alignment_dir)):
logging.info(f'{f} has already been processed. Skipping...')
continue
os.makedirs(alignment_dir)
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
fp.write(f'>query\n{seq}')
def main(args):
    """Precompute alignments for every sequence found in args.input_dir.

    Work is sharded first across SLURM nodes (when running under SLURM),
    then across args.no_tasks local worker threads. When an mmCIF cache with
    per-chain sequences is available, chains are deduplicated by sequence so
    each unique sequence is aligned exactly once.
    """
    # Build the alignment tool runner
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus_per_task,
    )

    files = list(os.listdir(args.input_dir))

    # Optional mmCIF cache: used both to skip finished proteins and to
    # deduplicate chains by sequence.
    if(args.mmcif_cache is not None):
        with open(args.mmcif_cache, "r") as fp:
            cache = json.load(fp)
    else:
        cache = None

    dirs = []
    if(cache is not None and args.filter):
        dirs = set(os.listdir(args.output_dir))

        def prot_is_done(f):
            # A protein is "done" only if every one of its chains already
            # has an alignment directory in the output directory.
            prot_id = os.path.splitext(f)[0]
            if(prot_id in cache):
                chain_ids = cache[prot_id]["chain_ids"]
                for c in chain_ids:
                    full_name = prot_id + "_" + c
                    if(not full_name in dirs):
                        return False
            else:
                return False
            return True

        files = [f for f in files if not prot_is_done(f)]

    def split_up_arglist(arglist):
        # Split up the survivors: shard across SLURM nodes first (if any),
        # then round-robin across the local tasks.
        if(os.environ.get("SLURM_JOB_NUM_NODES", 0)):
            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
            if(num_nodes > 1):
                node_id = int(os.environ["SLURM_NODEID"])
                logging.warning(f"Num nodes: {num_nodes}")
                logging.warning(f"Node ID: {node_id}")
                arglist = arglist[node_id::num_nodes]

        t_arglist = []
        for i in range(args.no_tasks):
            t_arglist.append(arglist[i::args.no_tasks])

        return t_arglist

    # If the cache stores sequences, group chains by sequence. Guard against
    # an empty cache before peeking at its first value — next() on an empty
    # iterator would raise StopIteration.
    if(cache is not None and len(cache) > 0
            and "seqs" in next(iter(cache.values()))):
        seq_group_dict = {}
        for f in files:
            prot_id = os.path.splitext(f)[0]
            if(prot_id in cache):
                prot_cache = cache[prot_id]
                chains_seqs = zip(
                    prot_cache["chain_ids"], prot_cache["seqs"]
                )
                for chain, seq in chains_seqs:
                    chain_name = prot_id + "_" + chain
                    if(chain_name not in dirs):
                        l = seq_group_dict.setdefault(seq, [])
                        l.append(chain_name)

        func = partial(run_seq_group_alignments,
            alignment_runner=alignment_runner,
            args=args
        )
        seq_groups = [(k, v) for k, v in seq_group_dict.items()]
        # Sort them by group length so the tasks are approximately balanced
        seq_groups = sorted(seq_groups, key=lambda x: len(x[1]))
        task_arglist = [[a] for a in split_up_arglist(seq_groups)]
    else:
        # No per-sequence cache info: each task parses its input files and
        # runs alignments for whatever sequences it finds.
        func = partial(parse_and_align,
            alignment_runner=alignment_runner,
            args=args,
        )
        task_arglist = [[a] for a in split_up_arglist(files)]

    # Threads (not processes) suffice: the heavy work happens in the external
    # alignment subprocesses, which release the GIL while they run.
    threads = []
    for i, task_args in enumerate(task_arglist):
        print(f"Started thread {i}...")
        t = threading.Thread(target=func, args=task_args)
        threads.append(t)
        t.start()

    for t in threads:
        t.join()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -111,9 +232,19 @@ if __name__ == "__main__": ...@@ -111,9 +232,19 @@ if __name__ == "__main__":
help="Whether to crash on parsing errors" help="Whether to crash on parsing errors"
) )
parser.add_argument(
    "--cpus_per_task", type=int, default=cpu_count(),
    help="Number of CPUs to use"
)
parser.add_argument(
    "--mmcif_cache", type=str, default=None,
    help="Path to mmCIF cache. Used to filter files to be parsed"
)
parser.add_argument(
    "--no_tasks", type=int, default=1,
    help="Number of alignment worker threads to run on this node"
)
parser.add_argument(
    # argparse's type=bool is a trap: bool("False") is True, so any
    # non-empty value on the command line would enable filtering.
    # Parse the string explicitly instead (backward-compatible for
    # --filter True / --filter False usage).
    "--filter",
    type=lambda s: str(s).lower() not in ("false", "0", "no", ""),
    default=True,
    help="Whether to skip entries already present in the output directory"
)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment