Unverified Commit f434a278 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #439 from aqlaboratory/setup-improvements

Adds Documentation and minor quality of life fixes
parents 3eef7caa d8117ce3
name: openfold-venv name: openfold-env
channels: channels:
- conda-forge - conda-forge
- bioconda - bioconda
...@@ -10,26 +10,26 @@ dependencies: ...@@ -10,26 +10,26 @@ dependencies:
- pip - pip
- openmm=7.7 - openmm=7.7
- pdbfixer - pdbfixer
- cudatoolkit==11.3.* - pytorch-lightning
- pytorch-lightning==1.5.10 - biopython
- biopython==1.79 - numpy
- numpy==1.21 - pandas
- pandas==2.0
- PyYAML==5.4.1 - PyYAML==5.4.1
- requests - requests
- scipy==1.7 - scipy==1.7
- tqdm==4.62.2 - tqdm==4.62.2
- typing-extensions==3.10 - typing-extensions==4.0
- wandb==0.12.21 - wandb
- modelcif==0.7 - modelcif==0.7
- awscli - awscli
- ml-collections - ml-collections
- aria2 - aria2
- mkl==2024.0 - mkl=2024.0
- git - git
- bioconda::hmmer==3.3.2 - bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0 - bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- bioconda::mmseqs2
- pytorch::pytorch=1.12.* - pytorch::pytorch=1.12.*
- pip: - pip:
- deepspeed==0.12.4 - deepspeed==0.12.4
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
>6KWC_1
GSTIQPGTGYNNGYFYSYWNDGHGGVTYTNGPGGQFSVNWSNSGEFVGGKGWQPGTKNKVINFSGSYNPNGNSYLSVYGWSRNPLIEYYIVENFGTYNPSTGATKLGEVTSDGSVYDIYRTQRVNQPSIIGTATFYQYWSVRRNHRSSGSVNTANHFNAWAQQGLTLGTMDYQIVAVQGYFSSGSASITVS
#!/bin/bash
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export FASTA_DIR=./fasta_dir
export OUTPUT_DIR=./
export PRECOMPUTED_ALIGNMENT_DIR=./alignments
export MMCIF_DIR=/mmcifs # UPDATE with path to your mmcifs directory
python3 run_pretrained_openfold.py $FASTA_DIR \
$MMCIF_DIR \
--output_dir $OUTPUT_DIR \
--config_preset model_1_ptm \
--model_device "cuda:0" \
--data_random_seed 42 \
--use_precomputed_alignments $PRECOMPUTED_ALIGNMENT_DIR
{ {
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
...@@ -120,7 +111,7 @@ ...@@ -120,7 +111,7 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n", "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n", "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n", "os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n", "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n", "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n", "\n",
"try:\n", "try:\n",
...@@ -259,7 +250,7 @@ ...@@ -259,7 +250,7 @@
"from openfold.np import protein\n", "from openfold.np import protein\n",
"from openfold.np.relax import relax\n", "from openfold.np.relax import relax\n",
"from openfold.np.relax.utils import overwrite_b_factors\n", "from openfold.np.relax.utils import overwrite_b_factors\n",
"from openfold.utils.import_weights import import_jax_weights_\n", "from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n", "from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n", "\n",
"from IPython import display\n", "from IPython import display\n",
...@@ -582,7 +573,7 @@ ...@@ -582,7 +573,7 @@
" model_name,\n", " model_name,\n",
" )\n", " )\n",
" d = torch.load(params_name)\n", " d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n", " import_openfold_weights_(model=openfold_model, state_dict=d)\n",
" else:\n", " else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n", " raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n", "\n",
......
...@@ -62,7 +62,8 @@ def model_config( ...@@ -62,7 +62,8 @@ def model_config(
name, name,
train=False, train=False,
low_prec=False, low_prec=False,
long_sequence_inference=False long_sequence_inference=False,
use_deepspeed_evoformer_attention=False,
): ):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS # TRAINING PRESETS
...@@ -237,6 +238,9 @@ def model_config( ...@@ -237,6 +238,9 @@ def model_config(
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False c.model.evoformer_stack.tune_chunk_size = False
if use_deepspeed_evoformer_attention:
c.globals.use_deepspeed_evo_attention = True
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
......
...@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
with open(distillation_alignment_index_path, "r") as fp: with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp) self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self, stage=None):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
...@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict", mode="predict",
) )
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage=None):
generator = None generator = None
if self.batch_seed is not None: if self.batch_seed is not None:
generator = torch.Generator() generator = torch.Generator()
...@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
def val_dataloader(self): def val_dataloader(self):
if self.eval_dataset is not None: if self.eval_dataset is not None:
return self._gen_dataloader("eval") return self._gen_dataloader("eval")
return None return []
def predict_dataloader(self): def predict_dataloader(self):
return self._gen_dataloader("predict") return self._gen_dataloader("predict")
...@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self.training_mode = self.train_data_dir is not None self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
def setup(self): def setup(self, stage=None):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset, dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
......
...@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features["num_alignments"] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
return features return features
...@@ -590,7 +590,7 @@ def convert_monomer_features( ...@@ -590,7 +590,7 @@ def convert_monomer_features(
) -> FeatureDict: ) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models.""" """Reshapes and modifies monomer features for multimer models."""
converted = {} converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
unnecessary_leading_dim_feats = { unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length' 'sequence', 'domain_name', 'num_alignments', 'seq_length'
} }
...@@ -1296,7 +1296,7 @@ class DataPipelineMultimer: ...@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
) )
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
...@@ -24,7 +24,7 @@ import os ...@@ -24,7 +24,7 @@ import os
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import Any, Mapping, Optional, Sequence, Tuple
from Bio import PDB from Bio import PDB
from Bio.Data import SCOPData from Bio.Data import PDBData
import numpy as np import numpy as np
from openfold.data.errors import MultipleChainsError from openfold.data.errors import MultipleChainsError
...@@ -283,7 +283,7 @@ def parse( ...@@ -283,7 +283,7 @@ def parse(
author_chain = mmcif_to_author_chain_id[chain_id] author_chain = mmcif_to_author_chain_id[chain_id]
seq = [] seq = []
for monomer in seq_info: for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, "X") code = PDBData.protein_letters_3to1.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X") seq.append(code if len(code) == 1 else "X")
seq = "".join(seq) seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq author_chain_to_sequence[author_chain] = seq
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import re import re
import logging
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
} }
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys()))) convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
converted_state_dict = {} converted_state_dict = {}
for key, value in state_dict.items(): for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary # For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
# Add prefix for template modules # Add prefix for template layers
if new_key.startswith('template'): template_match = re.match(template_emb_re, new_key)
new_key = f'template_embedder.{new_key}' if template_match:
prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
...@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords): ...@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
def _superimpose_single(reference, coords): def _superimpose_single(reference, coords):
reference_np = reference.detach().cpu().numpy() reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().cpu().numpy() coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np) superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd) return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
......
...@@ -114,8 +114,7 @@ def tree_map(fn, tree, leaf_type): ...@@ -114,8 +114,7 @@ def tree_map(fn, tree, leaf_type):
elif isinstance(tree, leaf_type): elif isinstance(tree, leaf_type):
return fn(tree) return fn(tree)
else: else:
print(type(tree)) raise ValueError(f"Tree of type {type(tree)} not supported")
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import pickle import pickle
import random import random
import time import time
import json
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -178,7 +179,21 @@ def main(args): ...@@ -178,7 +179,21 @@ def main(args):
if args.config_preset.startswith("seq"): if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) config = model_config(
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.trace_model: if args.trace_model:
if not config.data.predict.fixed_size: if not config.data.predict.fixed_size:
...@@ -258,6 +273,11 @@ def main(args): ...@@ -258,6 +273,11 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]]) seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn) sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {} feature_dicts = {}
if is_multimer and args.openfold_checkpoint_path:
raise ValueError(
'`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode')
model_generator = load_models_from_command_line( model_generator = load_models_from_command_line(
config, config,
args.model_device, args.model_device,
...@@ -453,6 +473,13 @@ if __name__ == "__main__": ...@@ -453,6 +473,13 @@ if __name__ == "__main__":
"--cif_output", action="store_true", default=False, "--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)" help="Output predicted models in ModelCIF format instead of PDB format (default)"
) )
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser.add_argument(
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""
import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from tqdm import tqdm
def chain_dir_to_fasta(dir: Path) -> str:
    """
    Generate a single FASTA entry from a chain alignment directory.

    The query sequence is read from the first available alignment file
    (preference order: mgnify, uniref90, bfd_uniclust) and the directory
    name is used as the chain ID.

    Args:
        dir: Path to a chain directory containing .a3m alignment files.

    Returns:
        A FASTA-formatted string ">chain_id\\nsequence\\n".

    Raises:
        FileNotFoundError: If no known alignment file exists in the directory.
    """
    # Take the first alignment file that exists, in preference order.
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        alignment_file = dir / alignment_file_type
        if alignment_file.exists():
            break
    else:
        # Previously this fell through and tried to open the last candidate,
        # yielding a confusing error; fail explicitly instead.
        raise FileNotFoundError(
            f"No known alignment file found in chain directory {dir}"
        )

    with open(alignment_file, "r") as f:
        next(f)  # skip the first header line
        seq = next(f).strip()

        # Sanity check: the query sequence must occupy a single line, so the
        # next line (if any) has to be a new FASTA header.
        try:
            next_line = next(f)
        except StopIteration:
            pass
        else:
            assert next_line.startswith(">")  # ensure that sequence ended

    chain_id = dir.name
    return f">{chain_id}\n{seq}\n"
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
    """
    Generate a single FASTA entry from an alignment-db index entry.

    The query sequence is read from the byte range of the most-preferred
    alignment file recorded in the entry (preference order: mgnify,
    uniref90, bfd_uniclust).

    Args:
        index_entry: Index entry with a "db" filename and a "files" list of
            [name, start, size] records.
        db_dir: Directory containing the alignment database files.
        chain_id: Chain ID to use in the FASTA header.

    Returns:
        A FASTA-formatted string ">chain_id\\nsequence\\n".

    Raises:
        ValueError: If the entry contains none of the known alignment files.
    """
    db_file = db_dir / index_entry["db"]

    # Find the byte range of the most-preferred alignment file present.
    start = size = None
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        for file_info in index_entry["files"]:
            if file_info[0] == alignment_file_type:
                start, size = file_info[1], file_info[2]
                break
        if start is not None:
            # BUG FIX: the original only broke the inner loop, so later
            # (less preferred) file types kept overwriting start/size and
            # the LAST match won. Stop at the first match instead.
            break
    if start is None:
        # Previously start/size would be unbound here, raising a confusing
        # UnboundLocalError; fail explicitly instead.
        raise ValueError(
            f"No known alignment file found in index entry for chain {chain_id}"
        )

    with open(db_file, "rb") as f:
        f.seek(start)
        msa_lines = f.read(size).decode("utf-8").splitlines()

    seq = msa_lines[1]

    # Sanity check: the query sequence must occupy a single line, so the
    # next line (if any) has to be a new FASTA header.
    try:
        next_line = msa_lines[2]
    except IndexError:
        pass
    else:
        assert next_line.startswith(">")  # ensure that sequence ended

    return f">{chain_id}\n{seq}\n"
def main(
    output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
    """
    Generate a FASTA file from either an alignment-db index or a chain
    directory using multi-threading.

    Exactly one of *alignment_db_index* and *alignment_dir* must be given;
    supplying both or neither raises a ValueError.
    """
    # Validate the mutually exclusive source arguments up front.
    if alignment_dir and alignment_db_index:
        raise ValueError(
            "Only one of alignment_db_index and alignment_dir can be provided."
        )
    if not alignment_dir and not alignment_db_index:
        raise ValueError("Either alignment_db_index or alignment_dir must be provided.")

    entries = []
    if alignment_dir:
        print("Creating FASTA from alignment directory...")
        chain_dirs = list(alignment_dir.iterdir())
        with ThreadPoolExecutor() as pool:
            pending = [
                pool.submit(chain_dir_to_fasta, chain_dir)
                for chain_dir in chain_dirs
            ]
            # Collect results as workers finish, with a progress bar.
            for done in tqdm(as_completed(pending), total=len(chain_dirs)):
                entries.append(done.result())
    else:
        print("Creating FASTA from alignment dbs...")
        with open(alignment_db_index, "r") as f:
            index = json.load(f)

        # Database files live next to the index file.
        db_dir = alignment_db_index.parent
        with ThreadPoolExecutor() as pool:
            pending = [
                pool.submit(index_entry_to_fasta, index_entry, db_dir, chain_id)
                for chain_id, index_entry in index.items()
            ]
            for done in tqdm(as_completed(pending), total=len(index)):
                entries.append(done.result())

    output_path.write_text("".join(entries))
    print(f"FASTA file written to {output_path}.")
# CLI entry point. Exactly one of --alignment_db_index / --alignment_dir must
# be supplied; the mutual-exclusion check is performed inside main().
if __name__ == "__main__":
    parser = ArgumentParser(description=__doc__)
    # Positional: where to write the resulting FASTA file.
    parser.add_argument(
        "output_path",
        type=Path,
        help="Path to output FASTA file.",
    )
    # Source option 1: an alignment-db index file (databases are expected
    # to live in the same directory as the index).
    parser.add_argument(
        "--alignment_db_index",
        type=Path,
        help="Path to alignment-db index file.",
    )
    # Source option 2: a directory of per-chain alignment subdirectories.
    parser.add_argument(
        "--alignment_dir",
        type=Path,
        help="Path to alignment directory.",
    )
    args = parser.parse_args()
    main(args.output_path, args.alignment_db_index, args.alignment_dir)
from argparse import ArgumentParser
from pathlib import Path
import json
def main(args):
    """
    Extend an alignment-db super index with entries for duplicate-sequence chains.

    Reads the super index and a FASTA file covering all chains, then maps each
    chain that lacks an alignment entry to the entry of another chain with an
    identical sequence. Chains with no same-sequence counterpart are reported.

    Args:
        args: Parsed CLI namespace with `alignment_db_super_index_path`,
            `all_chains_fasta`, and `output_path` attributes.
    """
    # get the super index
    with open(args.alignment_db_super_index_path, "r") as fp:
        super_index = json.load(fp)

    # Get all chains and sequences. Parsing is header-driven rather than
    # assuming exactly two lines per record, so sequences wrapped over
    # multiple lines (standard FASTA) are handled too.
    chains_to_seqs = {}
    with open(args.all_chains_fasta, "r") as fp:
        chain = None
        seq_parts = []
        for line in fp:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if chain is not None:
                    chains_to_seqs[chain] = "".join(seq_parts)
                chain = line[1:].strip()
                seq_parts = []
            else:
                seq_parts.append(line)
        if chain is not None:
            chains_to_seqs[chain] = "".join(seq_parts)

    chains_w_alignments = set(super_index.keys())
    chains_wo_alignments = set(chains_to_seqs.keys()) - chains_w_alignments
    # NOTE(review): assumes every indexed chain appears in the FASTA file;
    # a missing chain raises KeyError here (same as the original behavior).
    seq_to_chain_w_alignment = {
        chains_to_seqs[chain]: chain for chain in chains_w_alignments
    }
    print("Unique sequences with alignments:", len(seq_to_chain_w_alignment))

    # map chain without alignment to alignment entry of another chain with the
    # same sequence
    remaining_unaligned_chains = []
    for chain in chains_wo_alignments:
        seq = chains_to_seqs[chain]
        try:
            corresponding_alignment = super_index[seq_to_chain_w_alignment[seq]]
        # no corresponding chain with alignment found
        except KeyError:
            remaining_unaligned_chains.append(chain)
            continue
        super_index[chain] = corresponding_alignment

    with open(args.output_path, "w") as fp:
        json.dump(super_index, fp)

    print(
        f"No corresponding alignment found for the following {len(remaining_unaligned_chains)} chains:",
        remaining_unaligned_chains,
    )
# CLI entry point. Reads an existing super index plus an all-chains FASTA
# file and writes an augmented super index to output_path.
if __name__ == "__main__":
    parser = ArgumentParser(
        description="""
        If the alignment-db index was created on unique-chain alignments only,
        this will add the missing chain entries to the super-index file based on
        a .fasta file that contains sequences for all chains.

        Note that this only modifies the index and not the database itself, as
        the duplicate sequences will just point to the same alignments.
        """
    )
    # Positional: the existing super index to augment (read-only).
    parser.add_argument(
        "alignment_db_super_index_path",
        type=Path,
        help="Path to alignment-db super index file.",
    )
    # Positional: destination for the augmented super index.
    parser.add_argument(
        "output_path", type=Path, help="Write the output super index to this path."
    )
    # Positional: FASTA file with one entry per chain, used to match
    # alignment-less chains to same-sequence chains that have alignments.
    parser.add_argument(
        "all_chains_fasta",
        type=Path,
        help="Path to the fasta file containing sequences for all chains.",
    )
    args = parser.parse_args()
    main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment