Commit 9ce96fb5 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Merge branch 'main' of ssh://github.com/aqlaboratory/openfold into main

parents 3d5e8740 e1c7c9e7
...@@ -86,7 +86,7 @@ MMseqs2 should be split according to the memory available on the system). ...@@ -86,7 +86,7 @@ MMseqs2 should be split according to the memory available on the system).
Alternatively, you can use raw MSAs from Alternatively, you can use raw MSAs from
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading [ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
the database, use `scripts/prepare_proteinnet_msas.py` to convert the data into the database, use `scripts/prep_proteinnet_msas.py` to convert the data into
a format recognized by the OpenFold parser. The resulting directory becomes the a format recognized by the OpenFold parser. The resulting directory becomes the
`alignment_dir` used in subsequent steps. Use `scripts/unpack_proteinnet.py` to `alignment_dir` used in subsequent steps. Use `scripts/unpack_proteinnet.py` to
extract `.core` files from ProteinNet text files. extract `.core` files from ProteinNet text files.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
from functools import reduce from functools import reduce, wraps
from operator import add from operator import add
import numpy as np import numpy as np
...@@ -71,7 +71,7 @@ def make_template_mask(protein): ...@@ -71,7 +71,7 @@ def make_template_mask(protein):
def curry1(f): def curry1(f):
"""Supply all arguments but the first.""" """Supply all arguments but the first."""
@wraps(f)
def fc(*args, **kwargs): def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs) return lambda x: f(x, *args, **kwargs)
...@@ -145,7 +145,10 @@ def squeeze_features(protein): ...@@ -145,7 +145,10 @@ def squeeze_features(protein):
if k in protein: if k in protein:
final_dim = protein[k].shape[-1] final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1: if isinstance(final_dim, int) and final_dim == 1:
protein[k] = torch.squeeze(protein[k], dim=-1) if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]: for k in ["seq_length", "num_alignments"]:
if k in protein: if k in protein:
...@@ -162,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -162,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
gap_idx = 21 gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx) msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
protein["msa"] = torch.where( protein["msa"] = torch.where(
msa_mask, torch.ones_like(protein["msa"]) * x_idx, protein["msa"] msa_mask,
torch.ones_like(protein["msa"]) * x_idx,
protein["msa"]
) )
aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
...@@ -199,6 +204,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None): ...@@ -199,6 +204,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return protein return protein
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = distillation
return protein
@curry1 @curry1
def sample_msa_distillation(protein, max_seq): def sample_msa_distillation(protein, max_seq):
if(protein["is_distillation"] == 1): if(protein["is_distillation"] == 1):
...@@ -349,7 +359,7 @@ def make_msa_mask(protein): ...@@ -349,7 +359,7 @@ def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded.""" """Mask features are all ones, but will later be zero-padded."""
protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32) protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
protein["msa_row_mask"] = torch.ones( protein["msa_row_mask"] = torch.ones(
protein["msa"].shape[0], dtype=torch.float32 (protein["msa"].shape[0]), dtype=torch.float32
) )
return protein return protein
...@@ -602,17 +612,18 @@ def make_atom14_masks(protein): ...@@ -602,17 +612,18 @@ def make_atom14_masks(protein):
dtype=torch.float32, dtype=torch.float32,
device=protein["aatype"].device, device=protein["aatype"].device,
) )
protein_aatype = protein['aatype'].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array # create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein # with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein["aatype"]] residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein["aatype"]] residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back # create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein["aatype"]] residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask # create the corresponding mask
...@@ -626,7 +637,7 @@ def make_atom14_masks(protein): ...@@ -626,7 +637,7 @@ def make_atom14_masks(protein):
atom_type = rc.atom_order[atom_name] atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1 restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein["aatype"]] residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask protein["atom37_atom_exists"] = residx_atom37_mask
return protein return protein
......
...@@ -17,7 +17,7 @@ from setuptools import setup ...@@ -17,7 +17,7 @@ from setuptools import setup
setup( setup(
name='openfold', name='openfold',
version='1.0.0', version='0.1.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind', author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com', author_email='gahdritz@gmail.com',
......
...@@ -17,3 +17,17 @@ consts = mlc.ConfigDict( ...@@ -17,3 +17,17 @@ consts = mlc.ConfigDict(
"c_e": 64, "c_e": 64,
} }
) )
config = mlc.ConfigDict(
{
"data": {
"common": {
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
}
}
}
)
...@@ -5,12 +5,15 @@ import os ...@@ -5,12 +5,15 @@ import os
import pickle import pickle
import numpy import numpy as np
import torch import torch
import unittest import unittest
from data.data_transforms import make_seq_mask from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
from openfold.config import model_config correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \
make_msa_feat, crop_templates, make_atom14_masks
from tests.config import config
class TestDataTransforms(unittest.TestCase): class TestDataTransforms(unittest.TestCase):
...@@ -25,6 +28,221 @@ class TestDataTransforms(unittest.TestCase): ...@@ -25,6 +28,221 @@ class TestDataTransforms(unittest.TestCase):
assert 'seq_mask' in protein assert 'seq_mask' in protein
assert protein['seq_mask'].shape == torch.Size((seq.shape[0], 20)) assert protein['seq_mask'].shape == torch.Size((seq.shape[0], 20))
def test_add_distillation_flag(self):
protein = {}
protein = add_distillation_flag.__wrapped__(protein, True)
print(protein)
assert 'is_distillation' in protein
assert protein['is_distillation'] is True
def test_make_all_atom_aatype(self):
seq = torch.tensor([range(20)], dtype=torch.int64).transpose(0, 1)
seq_one_hot = torch.FloatTensor(seq.shape[0], 20).zero_()
seq_one_hot.scatter_(1, seq, 1)
protein_aatype = torch.tensor(seq_one_hot)
protein = {'aatype': protein_aatype}
protein = make_all_atom_aatype(protein)
print(protein)
assert 'all_atom_aatype' in protein
assert protein['all_atom_aatype'].shape == protein['aatype'].shape
def test_fix_templates_aatype(self):
template_seq = torch.tensor(list(range(20))*2, dtype=torch.int64)
template_seq = template_seq.unsqueeze(0).transpose(0, 1)
template_seq_one_hot = torch.FloatTensor(template_seq.shape[0], 20).zero_()
template_seq_one_hot.scatter_(1, template_seq, 1)
template_aatype = torch.tensor(template_seq_one_hot).unsqueeze(0)
protein = {'template_aatype': template_aatype}
protein = fix_templates_aatype(protein)
print(protein)
template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2])
assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
def test_correct_msa_restypes(self):
with open("../test_data/features.pkl", 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = correct_msa_restypes(protein)
print(protein)
assert torch.all(torch.eq(torch.tensor(features['msa'].shape), torch.tensor(protein['msa'].shape)))
def test_squeeze_features(self):
with open("../test_data/features.pkl", "rb") as file:
features = pickle.load(file)
print(os.path.realpath(file.name), 'Keys: ', features.keys())
features_list = [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_mask']
protein = {'aatype': torch.tensor(features['aatype'])}
for k in features_list:
if k in features:
print(k, features[k].dtype)
if k in ['domain_name', 'sequence']:
protein[k] = np.expand_dims(features[k], -1)
else:
protein[k] = torch.tensor(features[k]).unsqueeze(-1)
for k in ['seq_length', 'num_alignments']:
if k in protein:
protein[k] = torch.tensor(protein[k]).unsqueeze(0)
protein_squeezed = squeeze_features(protein)
print(protein)
for k in features_list:
if k in protein:
print(k, protein_squeezed[k].shape, features[k].shape)
assert protein_squeezed[k].shape == features[k].shape
def test_randomly_replace_msa_with_unknown(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa']),
'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1)}
replace_proportion = 0.15
x_idx = 20
protein = randomly_replace_msa_with_unknown.__wrapped__(protein, replace_proportion)
unknown_proportion_in_msa = torch.bincount(protein['msa'].flatten()) / torch.numel(protein['msa'])
unknown_proportion_in_seq = torch.bincount(protein['aatype'].flatten()) / torch.numel(protein['aatype'])
print(protein)
print('Proportion of X in MSA: ', unknown_proportion_in_msa[x_idx])
print('Proportion of X in sequence: ', unknown_proportion_in_seq[x_idx])
def test_sample_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
max_seq = 1000
keep_extra = True
protein = {}
for k in MSA_FEATURE_NAMES:
if k in features:
protein[k] = torch.tensor(features[k])
protein_processed = sample_msa.__wrapped__(protein.copy(), max_seq, keep_extra)
print(protein)
for k in MSA_FEATURE_NAMES:
if k in protein and keep_extra:
assert protein_processed[k].shape[0] == min(protein[k].shape[0], max_seq)
assert 'extra_'+k in protein_processed
print('extra_'+str(k), protein_processed['extra_'+k].shape)
print('msa', protein[k].shape[0] - min(protein[k].shape[0], max_seq))
assert protein_processed['extra_'+k].shape[0] == protein[k].shape[0] - min(protein[k].shape[0], max_seq)
def test_crop_extra_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
max_extra_msa = 10
protein = {'extra_msa': torch.tensor(features['msa'])}
num_seq = protein["extra_msa"].shape[0]
protein = crop_extra_msa.__wrapped__(protein, max_extra_msa)
print(protein)
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
assert protein["extra_" + k].shape[0] == min(max_extra_msa, num_seq)
def test_delete_extra_msa(self):
protein = {'extra_msa': torch.rand((512, 100, 23))}
extra_msa_has_deletion_shape = list(protein['extra_msa'].shape)
extra_msa_has_deletion_shape[2] = 1
protein['extra_deletion_matrix'] = torch.rand(extra_msa_has_deletion_shape)
protein = delete_extra_msa(protein)
print(protein)
for k in MSA_FEATURE_NAMES:
assert 'extra_' + k not in protein
assert 'extra_msa' not in protein
def test_nearest_neighbor_clusters(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as f:
features = pickle.load(f)
protein = {'msa': torch.tensor(features['true_msa'][0], dtype=torch.int64),
'msa_mask': torch.tensor(features['msa_mask'][0], dtype=torch.int64),
'extra_msa': torch.tensor(features['extra_msa'][0], dtype=torch.int64),
'extra_msa_mask': torch.tensor(features['extra_msa_mask'][0], dtype=torch.int64)}
protein = nearest_neighbor_clusters.__wrapped__(protein, 0)
print(protein)
assert 'extra_cluster_assignment' in protein
def test_make_msa_mask(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
msa_mat = torch.tensor(features['msa'])
protein = {'msa': msa_mat}
protein = make_msa_mask(protein)
print(protein)
assert 'msa_row_mask' in protein
assert protein['msa_row_mask'].shape[0] == msa_mat.shape[0]
def test_make_hhblits_profile(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
assert 'hhblits_profile' in protein
assert protein['hhblits_profile'].shape == torch.Size((protein['msa'].shape[1], 22))
def test_make_masked_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
print(protein)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
assert protein['bert_mask'].sum() >= 0
assert torch.all(torch.eq(
protein['true_msa'] * (1-protein['bert_mask']), protein['msa'] * (1-protein['bert_mask'])))
def test_make_msa_feat(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'between_segment_residues': torch.tensor(features['between_segment_residues']),
'msa': torch.tensor(features['msa'], dtype=torch.int64),
'deletion_matrix': torch.tensor(features['deletion_matrix_int']),
'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1)}
protein = make_msa_feat.__wrapped__(protein)
assert 'msa_feat' in protein
assert 'target_feat' in protein
assert protein['target_feat'].shape == torch.Size((protein['msa'].shape[1], 22))
assert protein['msa_feat'].shape == torch.Size((*protein['msa'].shape, 25))
def test_crop_templates(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as f:
features = pickle.load(f)
protein = {'template_aatype': torch.tensor(features['true_msa'][0]),
'template_all_atom_masks': torch.tensor(features['msa_mask'][0])}
max_templates = 2
protein = crop_templates.__wrapped__(protein, max_templates)
assert protein['template_aatype'].shape[0] == max_templates
assert protein['template_all_atom_masks'].shape[0] == max_templates
def test_make_atom14_masks(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as file:
features = pickle.load(file)
protein = {'aatype': torch.tensor(features['aatype'][0])}
protein = make_atom14_masks(protein)
print(protein)
assert 'atom14_atom_exists' in protein
assert 'residx_atom14_to_atom37' in protein
assert 'residx_atom37_to_atom14' in protein
assert 'atom37_atom_exists' in protein
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment