Unverified Commit bc075004 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #404 from jnwei/multimer-small-edits

Type fixes and README changes for multimer branch
parents 5e0616b6 a2adb147
......@@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
- run: pip install --upgrade pip
- run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
......@@ -9,4 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass/
......@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \
"https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH
......
......@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```
Any work that cites OpenFold should also cite AlphaFold.
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
......@@ -110,12 +110,12 @@ def make_sequence_features(
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
[description.encode("utf-8")], dtype=object
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
[sequence.encode("utf-8")], dtype=object
)
return features
......@@ -148,7 +148,7 @@ def make_mmcif_features(
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=object
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
......@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = {
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_domain_names": object,
"template_sequence": object,
"template_sum_probs": np.float32,
}
......@@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}
......
......@@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types(
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"]
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
......
......@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
Args:
batch: in this funtion batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
......@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
......@@ -123,7 +123,7 @@ def parse_fasta(data):
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs
......
......@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(
alignment_dir,
os.path.join(alignment_dir, tag),
)
local_alignment_dir = os.path.join(alignment_dir, tag),
if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...")
......
......@@ -113,10 +113,10 @@ else:
setup(
name='openfold',
version='1.0.1',
version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
author='OpenFold Team',
author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
......
......@@ -6,6 +6,7 @@ import sys
import unittest
import numpy as np
import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold
......@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function"
)
return params
def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff = torch.abs(expected - actual)
err = compare_func(abs_diff)
zero_tensor = torch.tensor(0, dtype=err.dtype)
rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
def assert_max_abs_diff_small(expected, actual, eps):
    """Assert that the largest elementwise |expected - actual| is within eps."""
    _assert_abs_diff_small_base(torch.max, expected, actual, eps)
def assert_mean_abs_diff_small(expected, actual, eps):
    """Assert that the mean elementwise |expected - actual| is within eps."""
    _assert_abs_diff_small_base(torch.mean, expected, actual, eps)
import ml_collections as mlc
consts = mlc.ConfigDict(
monomer_consts = mlc.ConfigDict(
{
"model": "model_1_ptm", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": False, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 23, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
multimer_consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
......@@ -24,6 +49,8 @@ consts = mlc.ConfigDict(
}
)
consts = monomer_consts
config = mlc.ConfigDict(
{
"data": {
......
......@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = {
k: v for k, v in batch.items() if k.startswith("template_")
......@@ -276,8 +273,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error {err}')
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
def test_compare_model(self):
"""
......@@ -310,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32)
# print(batch["target_feat"].shape)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
......@@ -335,8 +332,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.mean(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
if __name__ == "__main__":
......
......@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
key = jax.random.PRNGKey(42)
out_gt = f.apply(params, key, activations, masks)
......@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
......@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(
params, None, msa_act, msa_mask, pair_act
......@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase):
......@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase):
......@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__":
......
......@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......
......@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase):
def setUp(self):
"""
......@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57])
}
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2])
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5])
self.assertIn(int(anchor_pred_asym), [1, 2])
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5])
anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self):
batch = {
......@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325
batch = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment