"worker/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "4698c0f4cc5de87437fcf268387eb99828833515"
Unverified commit 49ab0539, authored by Jennifer Wei and committed by GitHub
Browse files

Merge pull request #407 from jnwei/pl_upgrades

Pytorch lightning upgrades
parents df4dfacb f0fc7d91
...@@ -5,7 +5,7 @@ jobs: ...@@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-python@v4 - uses: actions/setup-python@v5
- run: pip install --upgrade pip - run: pip install --upgrade pip
- run: pip install flake8 - run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
...@@ -9,4 +9,4 @@ dist ...@@ -9,4 +9,4 @@ dist
data data
openfold/resources/ openfold/resources/
tests/test_data/ tests/test_data/
cutlass cutlass/
...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/ ...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh && rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH ENV PATH /opt/conda/bin:$PATH
......
...@@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \ ...@@ -351,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \ --output_dir ./ \
--model_device "cuda:0" \ --model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \ --config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \ --openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \ --uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \ --pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite: ...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM} primaryClass={q-bio.BM}
} }
``` ```
Any work that cites OpenFold should also cite AlphaFold. Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
This diff is collapsed.
...@@ -3,15 +3,15 @@ channels: ...@@ -3,15 +3,15 @@ channels:
- conda-forge - conda-forge
- bioconda - bioconda
dependencies: dependencies:
- conda-forge::openmm=7.5.1 - openmm=7.7
- conda-forge::pdbfixer - pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2 - bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0 - bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- pip: - pip:
- biopython==1.79 - biopython==1.79
- dm-tree==0.1.6 - dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
...@@ -21,14 +21,11 @@ import dataclasses ...@@ -21,14 +21,11 @@ import dataclasses
from multiprocessing import cpu_count from multiprocessing import cpu_count
import tempfile import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np import numpy as np
import torch import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray] FeatureDict = MutableMapping[str, np.ndarray]
...@@ -704,10 +701,10 @@ class DataPipeline: ...@@ -704,10 +701,10 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = {} msa_data = {}
if(alignment_index is not None): if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
...@@ -718,14 +715,14 @@ class DataPipeline: ...@@ -718,14 +715,14 @@ class DataPipeline:
for (name, start, size) in alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name) filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if ext == ".a3m":
msa = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
# The "hmm_output" exception is a crude way to exclude # The "hmm_output" exception is a crude way to exclude
# multimer template hits. # multimer template hits.
# Multimer "uniprot_hits" processed separately. # Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]): elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size)) msa = parsers.parse_stockholm(read_msa(start, size))
else: else:
continue continue
...@@ -734,13 +731,22 @@ class DataPipeline: ...@@ -734,13 +731,22 @@ class DataPipeline:
fp.close() fp.close()
else: else:
# Now will split the following steps into multiple processes for f in os.listdir(alignment_dir):
current_directory = os.path.dirname(os.path.abspath(__file__)) path = os.path.join(alignment_dir, f)
cmd = f"{current_directory}/tools/parse_msa_files.py" filename, ext = os.path.splitext(f)
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip() if ext == ".a3m":
msa_data = pickle.load((open(msa_data_path,'rb'))) with open(path, "r") as fp:
os.remove(msa_data_path) msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data return msa_data
......
...@@ -101,8 +101,8 @@ def empty_template_feats(n_res): ...@@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros( "template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32 (0, n_res, residue_constants.atom_type_num, 3), np.float32
), ),
"template_domain_names": np.array([''.encode()], dtype=np.object), "template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=np.object), "template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32), "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
} }
......
...@@ -90,15 +90,15 @@ def get_optimal_transform( ...@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch, input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has. Select the subunit that is less
one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
Args: Args:
batch: in this funtion batch is the full ground truth features batch: in this function batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features input_asym_id: A list of asym_ids that are in the cropped input features
Return: Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values()) min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length # If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities]) max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length] least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
...@@ -123,7 +123,7 @@ def parse_fasta(data): ...@@ -123,7 +123,7 @@ def parse_fasta(data):
][1:] ][1:]
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs return tags, seqs
......
...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join( local_alignment_dir = os.path.join(alignment_dir, tag)
alignment_dir,
os.path.join(alignment_dir, tag),
)
if args.use_precomputed_alignments is None: if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
......
...@@ -113,10 +113,10 @@ else: ...@@ -113,10 +113,10 @@ else:
setup( setup(
name='openfold', name='openfold',
version='1.0.1', version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind', author='OpenFold Team',
author_email='gahdritz@gmail.com', author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0', license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold', url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]), packages=find_packages(exclude=["tests", "scripts"]),
......
import ml_collections as mlc import ml_collections as mlc
consts = mlc.ConfigDict(
monomer_consts = mlc.ConfigDict(
{
"model": "model_1_ptm", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": False, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 23, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
multimer_consts = mlc.ConfigDict(
{ {
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3 "model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True "is_multimer": True, # monomer: False, multimer: True
...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict( ...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict(
} }
) )
consts = monomer_consts
config = mlc.ConfigDict( config = mlc.ConfigDict(
{ {
"data": { "data": {
......
...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = { template_feats = {
k: v for k, v in batch.items() if k.startswith("template_") k: v for k, v in batch.items() if k.startswith("template_")
...@@ -309,7 +306,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -309,7 +306,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[ batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14" "residx_atom37_to_atom14"
].long() ].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32) batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update( batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch) data_transforms.atom37_to_torsion_angles("template_")(batch)
......
...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym ...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels) merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase): class TestPermutation(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase): ...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2]) anchor_gt_asym = int(anchor_gt_asym)
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5]) anchor_pred_asym = {int(i) for i in anchor_pred_asym}
self.assertIn(int(anchor_pred_asym), [1, 2]) expected_anchors = {1, 2}
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5]) expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self): def test_2_permutation_pentamer(self):
batch = { batch = {
...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase): ...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome) self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome) self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self): def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325 nres_pad = 325 - 57 # suppose the cropping size is 325
batch = { batch = {
......
...@@ -235,6 +235,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -235,6 +235,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler( lr_scheduler = AlphaFoldLRScheduler(
optimizer, optimizer,
last_epoch=self.last_lr_step
) )
return { return {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment