Unverified Commit 7d227395 authored by Jannik Gut's avatar Jannik Gut Committed by GitHub
Browse files

Merge branch 'main' into main

parents b38b6078 f37d0d96
......@@ -20,6 +20,7 @@ import os
import pickle
import random
import time
import json
logging.basicConfig()
logger = logging.getLogger(__file__)
......@@ -131,7 +132,16 @@ def generate_feature_dict(
args,
):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1:
if "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
elif len(seqs) == 1:
tag = tags[0]
seq = seqs[0]
with open(tmp_fasta_path, "w") as fp:
......@@ -143,14 +153,6 @@ def generate_feature_dict(
alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode,
)
elif "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
else:
with open(tmp_fasta_path, "w") as fp:
fp.write(
......@@ -177,7 +179,21 @@ def main(args):
if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
config = model_config(
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.trace_model:
if not config.data.predict.fixed_size:
......@@ -261,6 +277,11 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {}
if is_multimer and args.openfold_checkpoint_path:
raise ValueError(
'`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode')
model_generator = load_models_from_command_line(
config,
args.model_device,
......@@ -459,6 +480,13 @@ if __name__ == "__main__":
"--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser.add_argument(
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
add_data_args(parser)
args = parser.parse_args()
......
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""
import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from tqdm import tqdm
def chain_dir_to_fasta(dir: Path) -> str:
    """
    Generate a FASTA entry for a single chain from its alignment directory.

    The chain ID is taken from the directory name and the query sequence is
    read from the first entry of the first known alignment file found.

    Args:
        dir: Chain directory (named after the chain) containing .a3m files.

    Returns:
        A FASTA-formatted string for the chain's query sequence.

    Raises:
        FileNotFoundError: If none of the known alignment files exist in `dir`.
    """
    # Take the first alignment file that exists, in order of preference.
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        alignment_file = dir / alignment_file_type
        if alignment_file.exists():
            break
    else:
        # Previously this fell through and open() failed on the last candidate
        # path with a confusing message; fail explicitly instead.
        raise FileNotFoundError(f"No known alignment file found in {dir}")

    with open(alignment_file, "r") as f:
        next(f)  # skip the first header line
        seq = next(f).strip()

        # Sanity check: the query sequence must fit on a single line, i.e.
        # the following line (if any) starts a new MSA entry.
        try:
            next_line = next(f)
        except StopIteration:
            pass
        else:
            assert next_line.startswith(">")  # ensure that sequence ended

    chain_id = dir.name
    return f">{chain_id}\n{seq}\n"
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
    """
    Generate a FASTA entry for a chain from its alignment-db index entry.

    Args:
        index_entry: Index entry with keys "db" (database filename) and
            "files" (list of [filename, start, size] triples).
        db_dir: Directory containing the alignment database files.
        chain_id: Chain ID to use as the FASTA header.

    Returns:
        A FASTA-formatted string for the chain's query sequence.

    Raises:
        ValueError: If the entry contains none of the known alignment files.
    """
    db_file = db_dir / index_entry["db"]

    # Find the first known alignment file, in order of preference. The old
    # code's inner `break` did not stop the outer loop, so a lower-priority
    # file could overwrite the offsets of a higher-priority one; it also left
    # `start`/`size` unbound (NameError) when nothing matched.
    offsets = None
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        for file_info in index_entry["files"]:
            if file_info[0] == alignment_file_type:
                offsets = (file_info[1], file_info[2])
                break
        if offsets is not None:
            break
    if offsets is None:
        raise ValueError(
            f"No known alignment file found in index entry for {chain_id}"
        )
    start, size = offsets

    with open(db_file, "rb") as f:
        f.seek(start)
        msa_lines = f.read(size).decode("utf-8").splitlines()

    seq = msa_lines[1]

    # Sanity check: the query sequence must fit on a single line.
    try:
        next_line = msa_lines[2]
    except IndexError:
        pass
    else:
        assert next_line.startswith(">")  # ensure that sequence ended

    return f">{chain_id}\n{seq}\n"
def main(
    output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
    """
    Build a FASTA file covering all chains, sourced either from a directory of
    per-chain alignment folders or from an alignment-db index, fanning the
    per-chain work out over a thread pool.
    """
    if alignment_dir and alignment_db_index:
        raise ValueError(
            "Only one of alignment_db_index and alignment_dir can be provided."
        )

    entries = []

    if alignment_dir:
        print("Creating FASTA from alignment directory...")
        chain_dirs = list(alignment_dir.iterdir())
        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(chain_dir_to_fasta, d) for d in chain_dirs]
            # collect in completion order; ordering of entries is not significant
            for done in tqdm(as_completed(pending), total=len(chain_dirs)):
                entries.append(done.result())
    elif alignment_db_index:
        print("Creating FASTA from alignment dbs...")
        with open(alignment_db_index, "r") as f:
            index = json.load(f)
        db_dir = alignment_db_index.parent
        with ThreadPoolExecutor() as pool:
            pending = [
                pool.submit(index_entry_to_fasta, entry, db_dir, cid)
                for cid, entry in index.items()
            ]
            for done in tqdm(as_completed(pending), total=len(index)):
                entries.append(done.result())
    else:
        raise ValueError("Either alignment_db_index or alignment_dir must be provided.")

    with open(output_path, "w") as f:
        f.write("".join(entries))

    print(f"FASTA file written to {output_path}.")
if __name__ == "__main__":
    # CLI entry point: one positional output path, two mutually exclusive
    # alignment sources (exclusivity is enforced inside main()).
    cli = ArgumentParser(description=__doc__)
    cli.add_argument("output_path", type=Path, help="Path to output FASTA file.")
    cli.add_argument(
        "--alignment_db_index", type=Path, help="Path to alignment-db index file."
    )
    cli.add_argument(
        "--alignment_dir", type=Path, help="Path to alignment directory."
    )
    opts = cli.parse_args()
    main(opts.output_path, opts.alignment_db_index, opts.alignment_dir)
from argparse import ArgumentParser
from pathlib import Path
import json
def main(args):
    """
    Extend an alignment-db super index with entries for duplicate-sequence
    chains.

    Chains that appear in the all-chains FASTA but not in the super index are
    mapped to the index entry of an already-aligned chain with the identical
    sequence. Chains whose sequence has no aligned representative are reported
    at the end.

    Args:
        args: Parsed CLI arguments with attributes
            `alignment_db_super_index_path`, `all_chains_fasta`, and
            `output_path`.
    """
    # get the super index
    with open(args.alignment_db_super_index_path, "r") as fp:
        super_index = json.load(fp)

    # get all chains and sequences; parse by ">" headers rather than a fixed
    # two-line stride, so blank lines and wrapped sequences don't break parsing
    chains_to_seqs = {}
    with open(args.all_chains_fasta, "r") as fp:
        chain = None
        for line in fp:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                chain = line[1:]
                chains_to_seqs[chain] = ""
            else:
                chains_to_seqs[chain] += line

    chains_w_alignments = set(super_index.keys())
    chains_wo_alignments = set(chains_to_seqs.keys()) - chains_w_alignments

    # one representative aligned chain per unique sequence (last one wins)
    seq_to_chain_w_alignment = {
        chains_to_seqs[chain]: chain for chain in chains_w_alignments
    }
    print("Unique sequences with alignments:", len(seq_to_chain_w_alignment))

    # map chain without alignment to alignment entry of another chain with the
    # same sequence
    remaining_unaligned_chains = []
    for chain in chains_wo_alignments:
        seq = chains_to_seqs[chain]
        try:
            corresponding_alignment = super_index[seq_to_chain_w_alignment[seq]]
        # no corresponding chain with alignment found
        except KeyError:
            remaining_unaligned_chains.append(chain)
            continue
        super_index[chain] = corresponding_alignment

    with open(args.output_path, "w") as fp:
        json.dump(super_index, fp)

    print(
        f"No corresponding alignment found for the following {len(remaining_unaligned_chains)} chains:",
        remaining_unaligned_chains,
    )
if __name__ == "__main__":
    # CLI entry point: all three arguments are positional paths.
    cli = ArgumentParser(
        description="""
        If the alignment-db index was created on unique-chain alignments only,
        this will add the missing chain entries to the super-index file based on
        a .fasta file that contains sequences for all chains.

        Note that this only modifies the index and not the database itself, as
        the duplicate sequences will just point to the same alignments.
        """
    )
    cli.add_argument(
        "alignment_db_super_index_path",
        type=Path,
        help="Path to alignment-db super index file.",
    )
    cli.add_argument(
        "output_path", type=Path, help="Write the output super index to this path."
    )
    cli.add_argument(
        "all_chains_fasta",
        type=Path,
        help="Path to the fasta file containing sequences for all chains.",
    )
    main(cli.parse_args())
This diff is collapsed.
This diff is collapsed.
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2024 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads OpenFold parameters.
# Downloads OpenFold SoloSeq (single sequence model) parameters.
#
# Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory
# Usage: bash download_openfold_soloseq_params.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment