Unverified Commit 29b5823e authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #419 from aqlaboratory/setup-improvements_additional-scripts

Duplicate expansion support
parents 6cba403b 04410d5e
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""
import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from tqdm import tqdm
def chain_dir_to_fasta(dir: Path) -> str:
    """
    Generates a FASTA string from a chain directory.

    Args:
        dir: Chain directory containing at least one known .a3m alignment
            file. (Parameter name shadows the ``dir`` builtin; kept for
            backward compatibility with existing callers.)

    Returns:
        A single FASTA entry ``>{chain_id}\\n{seq}\\n``, where the chain ID is
        the directory name and the sequence is the query (first MSA entry).

    Raises:
        FileNotFoundError: If none of the known alignment files exist in the
            directory.
    """
    # Take the first alignment file that exists, in order of preference.
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        alignment_file = dir / alignment_file_type
        if alignment_file.exists():
            break
    else:
        # Previously this fell through and open() failed on the last
        # candidate path with a misleading error; fail explicitly instead.
        raise FileNotFoundError(f"No alignment file found in {dir}.")

    with open(alignment_file, "r") as f:
        next(f)  # skip the first line (query FASTA header)
        seq = next(f).strip()

        # The query sequence must be on a single line; if another line
        # follows, it must be the next FASTA header.
        try:
            next_line = next(f)
        except StopIteration:
            pass
        else:
            assert next_line.startswith(">")  # ensure that sequence ended

    chain_id = dir.name
    return f">{chain_id}\n{seq}\n"
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
    """
    Generates a FASTA string from an alignment-db index entry.

    Args:
        index_entry: Index entry of the form
            ``{"db": str, "files": [(file_name, db_offset, file_length), ...]}``.
        db_dir: Directory containing the alignment database shards.
        chain_id: Chain ID used as the FASTA header.

    Returns:
        A single FASTA entry ``>{chain_id}\\n{seq}\\n`` for the query sequence.

    Raises:
        ValueError: If the entry contains none of the known alignment files.
    """
    db_file = db_dir / index_entry["db"]

    # Find the first alignment file present, in order of preference. The
    # previous version kept scanning after a match, so a later (less
    # preferred) file type silently overrode an earlier one, and start/size
    # were unbound if nothing matched at all.
    location = None
    for alignment_file_type in [
        "mgnify_hits.a3m",
        "uniref90_hits.a3m",
        "bfd_uniclust_hits.a3m",
    ]:
        for file_info in index_entry["files"]:
            if file_info[0] == alignment_file_type:
                location = (file_info[1], file_info[2])
                break
        if location is not None:
            break
    if location is None:
        raise ValueError(f"No alignment file found in index entry for {chain_id}.")
    start, size = location

    with open(db_file, "rb") as f:
        f.seek(start)
        msa_lines = f.read(size).decode("utf-8").splitlines()

    # msa_lines[0] is the query header; the sequence must be one line, so a
    # third line (if any) must be the next FASTA header.
    seq = msa_lines[1]
    try:
        next_line = msa_lines[2]
    except IndexError:
        pass
    else:
        assert next_line.startswith(">")  # ensure that sequence ended

    return f">{chain_id}\n{seq}\n"
def main(
    output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
    """
    Generate a FASTA file from either an alignment-db index or a chain
    directory using multi-threading.

    Exactly one of ``alignment_db_index`` and ``alignment_dir`` must be
    provided; a ValueError is raised otherwise.
    """
    if alignment_dir and alignment_db_index:
        raise ValueError(
            "Only one of alignment_db_index and alignment_dir can be provided."
        )

    entries = []
    if alignment_dir:
        print("Creating FASTA from alignment directory...")
        chain_dirs = list(alignment_dir.iterdir())
        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(chain_dir_to_fasta, d) for d in chain_dirs]
            # collect in completion order; entry order in the output is arbitrary
            for done in tqdm(as_completed(pending), total=len(chain_dirs)):
                entries.append(done.result())
    elif alignment_db_index:
        print("Creating FASTA from alignment dbs...")
        with open(alignment_db_index, "r") as f:
            index = json.load(f)
        db_dir = alignment_db_index.parent
        with ThreadPoolExecutor() as pool:
            pending = [
                pool.submit(index_entry_to_fasta, entry, db_dir, cid)
                for cid, entry in index.items()
            ]
            for done in tqdm(as_completed(pending), total=len(index)):
                entries.append(done.result())
    else:
        raise ValueError("Either alignment_db_index or alignment_dir must be provided.")

    with open(output_path, "w") as f:
        f.write("".join(entries))

    print(f"FASTA file written to {output_path}.")
if __name__ == "__main__":
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"output_path",
type=Path,
help="Path to output FASTA file.",
)
parser.add_argument(
"--alignment_db_index",
type=Path,
help="Path to alignment-db index file.",
)
parser.add_argument(
"--alignment_dir",
type=Path,
help="Path to alignment directory.",
)
args = parser.parse_args()
main(args.output_path, args.alignment_db_index, args.alignment_dir)
...@@ -5,17 +5,19 @@ super index, meaning that "unify_alignment_db_indices.py" does not need to be ...@@ -5,17 +5,19 @@ super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version. multiprocessing and is much faster than the old version.
""" """
import argparse import argparse
import json
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import json from math import ceil
from multiprocessing import cpu_count
from pathlib import Path from pathlib import Path
from typing import List
from tqdm import tqdm from tqdm import tqdm
from math import ceil
def split_file_list(file_list, n_shards): def split_file_list(file_list: list[Path], n_shards: int):
""" """
Split up the total file list into n_shards sublists. Split up the total file list into n_shards sublists.
""" """
...@@ -29,13 +31,13 @@ def split_file_list(file_list, n_shards): ...@@ -29,13 +31,13 @@ def split_file_list(file_list, n_shards):
return split_list return split_list
def chunked_iterator(lst, chunk_size): def chunked_iterator(lst: list, chunk_size: int):
"""Iterate over a list in chunks of size chunk_size.""" """Iterate over a list in chunks of size chunk_size."""
for i in range(0, len(lst), chunk_size): for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size] yield lst[i : i + chunk_size]
def read_chain_dir(chain_dir) -> dict: def read_chain_dir(chain_dir: Path) -> dict:
""" """
Read all alignment files in a single chain directory and return a dict Read all alignment files in a single chain directory and return a dict
mapping chain name to file names and bytes. mapping chain name to file names and bytes.
...@@ -48,7 +50,6 @@ def read_chain_dir(chain_dir) -> dict: ...@@ -48,7 +50,6 @@ def read_chain_dir(chain_dir) -> dict:
pdb_id = pdb_id.lower() pdb_id = pdb_id.lower()
chain_name = f"{pdb_id}_{chain}" chain_name = f"{pdb_id}_{chain}"
file_data = [] file_data = []
for file_path in sorted(chain_dir.iterdir()): for file_path in sorted(chain_dir.iterdir()):
...@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict: ...@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict:
return {chain_name: file_data} return {chain_name: file_data}
def process_chunk(chain_files: List[Path]) -> dict: def process_chunk(chain_files: list[Path]) -> dict:
""" """
Returns the file names and bytes for all chains in a chunk of files. Returns the file names and bytes for all chains in a chunk of files.
""" """
...@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict: ...@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict:
def create_shard( def create_shard(
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict: ) -> dict:
""" """
Creates a single shard of the alignment database, and returns the Creates a single shard of the alignment database, and returns the
...@@ -92,7 +93,7 @@ def create_shard( ...@@ -92,7 +93,7 @@ def create_shard(
CHUNK_SIZE = 200 CHUNK_SIZE = 200
shard_index = defaultdict( shard_index = defaultdict(
create_index_default_dict create_index_default_dict
) # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...} ) # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE) chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
pbar_desc = f"Shard {shard_num}" pbar_desc = f"Shard {shard_num}"
...@@ -101,7 +102,11 @@ def create_shard( ...@@ -101,7 +102,11 @@ def create_shard(
db_offset = 0 db_offset = 0
db_file = open(output_path, "wb") db_file = open(output_path, "wb")
for files_chunk in tqdm( for files_chunk in tqdm(
chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False chunk_iter,
total=ceil(len(shard_files) / CHUNK_SIZE),
desc=pbar_desc,
position=shard_num,
leave=False,
): ):
# get processed files for one chunk # get processed files for one chunk
chunk_data = process_chunk(files_chunk) chunk_data = process_chunk(files_chunk)
...@@ -125,9 +130,17 @@ def create_shard( ...@@ -125,9 +130,17 @@ def create_shard(
def main(args): def main(args):
alignment_dir = args.alignment_dir alignment_dir = args.alignment_dir
output_dir = args.output_db_path output_dir = args.output_db_path
output_dir.mkdir(exist_ok=True, parents=True)
output_db_name = args.output_db_name output_db_name = args.output_db_name
n_shards = args.n_shards n_shards = args.n_shards
n_cpus = cpu_count()
if n_shards > n_cpus:
print(
f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
"This may result in slower performance. Consider using a smaller number of shards."
)
# get all chain dirs in alignment_dir # get all chain dirs in alignment_dir
print("Getting chain directories...") print("Getting chain directories...")
all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())]) all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
...@@ -153,6 +166,30 @@ def main(args): ...@@ -153,6 +166,30 @@ def main(args):
super_index.update(shard_index) super_index.update(shard_index)
print("\nCreated all shards.") print("\nCreated all shards.")
if args.duplicate_chains_file:
print("Extending super index with duplicate chains...")
duplicates_added = 0
with open(args.duplicate_chains_file, "r") as fp:
duplicate_chains = [line.strip().split() for line in fp]
for chains in duplicate_chains:
# find representative with alignment
for chain in chains:
if chain in super_index:
representative_chain = chain
break
else:
print(f"No representative chain found for {chains}, skipping...")
continue
# add duplicates to index
for chain in chains:
if chain != representative_chain:
super_index[chain] = super_index[representative_chain]
duplicates_added += 1
print(f"Added {duplicates_added} duplicate chains to index.")
# write super index to file # write super index to file
print("\nWriting super index...") print("\nWriting super index...")
index_path = output_dir / f"{output_db_name}.index" index_path = output_dir / f"{output_db_name}.index"
...@@ -179,13 +216,27 @@ if __name__ == "__main__": ...@@ -179,13 +216,27 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"alignment_dir", "alignment_dir",
type=Path, type=Path,
help="""Path to precomputed alignment directory, with one subdirectory help="""Path to precomputed flattened alignment directory, with one
per chain.""", subdirectory per chain.""",
) )
parser.add_argument("output_db_path", type=Path) parser.add_argument("output_db_path", type=Path)
parser.add_argument("output_db_name", type=str) parser.add_argument("output_db_name", type=str)
parser.add_argument( parser.add_argument(
"n_shards", type=int, help="Number of shards to split the database into" "--n_shards",
type=int,
help="Number of shards to split the database into",
default=10,
)
parser.add_argument(
"--duplicate_chains_file",
type=Path,
help="""
Optional path to file containing duplicate chain information, where each
line contains chains that are 100% sequence identical. If provided,
duplicate chains will be added to the index and point to the same
underlying database entry as their representatives in the alignment dir.
""",
default=None,
) )
args = parser.parse_args() args = parser.parse_args()
......
"""
The OpenProteinSet alignment database is non-redundant, meaning that it only
stores one explicit representative alignment directory for all PDB chains in a
100% sequence identity cluster. In order to add explicit alignments for all PDB
chains, this script will add the missing chain directories and symlink them to
their representative alignment directories. This is required in order to train
OpenFold on the full PDB, not just one representative chain per cluster.
"""
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
    """
    Create duplicate directory symlinks for all chains in the given duplicate lists.

    Args:
        duplicate_chains (list[list[str]]): A list of lists, where each inner
            list contains chains that are 100% sequence identical.
        alignment_dir (Path): Path to flattened alignment directory, with one
            subdirectory per chain. Should be absolute so the created
            symlinks do not depend on the working directory.
    """
    print("Creating duplicate directory symlinks...")
    dirs_created = 0

    for chains in tqdm(duplicate_chains):
        # find the chain that has an alignment
        for chain in chains:
            if (alignment_dir / chain).exists():
                representative_chain = chain
                break
        else:
            # no chain in this cluster has an alignment directory
            print(f"No representative chain found for {chains}, skipping...")
            continue

        # create symlinks for all other chains in the cluster
        for chain in chains:
            if chain != representative_chain:
                target_path = alignment_dir / chain
                if target_path.exists():
                    print(f"Chain {chain} already exists, skipping...")
                else:
                    target_path.symlink_to(alignment_dir / representative_chain)
                    dirs_created += 1

    print(f"Created directories for {dirs_created} duplicate chains.")
def main(alignment_dir: Path, duplicate_chains_file: Path):
    """
    Read the duplicate-chains file and create symlinked chain directories.

    Each line of the file is a space-separated list of 100%-identical chains.
    """
    with open(duplicate_chains_file, "r") as fp:
        clusters = [list(line.strip().split()) for line in fp]

    # resolve to an absolute path so the symlink targets are unambiguous
    create_duplicate_dirs(clusters, alignment_dir.resolve())
if __name__ == "__main__":
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"alignment_dir",
type=Path,
help="""Path to flattened alignment directory, with one subdirectory
per chain.""",
)
parser.add_argument(
"duplicate_chains_file",
type=Path,
help="""Path to file containing duplicate chains, where each line
contains a space-separated list of chains that are 100%%
sequence identical.
""",
)
args = parser.parse_args()
main(args.alignment_dir, args.duplicate_chains_file)
...@@ -85,7 +85,7 @@ def main(args): ...@@ -85,7 +85,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser( parser = ArgumentParser(
description="Creates a sequence cluster file from a .fasta file using mmseqs2 with PDB settings." description=__doc__
) )
parser.add_argument( parser.add_argument(
"input_fasta", "input_fasta",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment