Commit 77860bb7 authored by Lukas Jarosch
Browse files

Improve type hints and formatting

parent 6ba0a594
......@@ -5,17 +5,18 @@ super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally, this script uses threading and
multiprocessing and is much faster than the old version.
"""
import argparse
import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import json
from math import ceil
from pathlib import Path
from typing import List
from tqdm import tqdm
from math import ceil
def split_file_list(file_list, n_shards):
def split_file_list(file_list: list[Path], n_shards: int):
"""
Split up the total file list into n_shards sublists.
"""
......@@ -29,26 +30,25 @@ def split_file_list(file_list, n_shards):
return split_list
def chunked_iterator(lst, chunk_size):
def chunked_iterator(lst: list, chunk_size: int):
    """
    Iterate over a list in consecutive chunks of at most chunk_size items.

    Chunks are yielded in order as slices of lst; the final chunk is shorter
    when len(lst) is not a multiple of chunk_size. An empty list yields
    nothing.

    Raises:
        ValueError: if chunk_size is smaller than 1. Because this is a
            generator, the error is raised when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error, so reject
    # both up front with an explicit message.
    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, but is {chunk_size}"
        )
    for start in range(0, len(lst), chunk_size):
        yield lst[start : start + chunk_size]
def read_chain_dir(chain_dir) -> dict:
def read_chain_dir(chain_dir: Path) -> dict:
"""
Read all alignment files in a single chain directory and return a dict
mapping chain name to file names and bytes.
"""
if not chain_dir.is_dir():
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")
# ensure that PDB IDs are all lowercase
pdb_id, chain = chain_dir.name.split("_")
pdb_id = pdb_id.lower()
chain_name = f"{pdb_id}_{chain}"
file_data = []
for file_path in sorted(chain_dir.iterdir()):
......@@ -62,7 +62,7 @@ def read_chain_dir(chain_dir) -> dict:
return {chain_name: file_data}
def process_chunk(chain_files: List[Path]) -> dict:
def process_chunk(chain_files: list[Path]) -> dict:
"""
Returns the file names and bytes for all chains in a chunk of files.
"""
......@@ -83,7 +83,7 @@ def create_index_default_dict() -> dict:
def create_shard(
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
"""
Creates a single shard of the alignment database, and returns the
......@@ -101,7 +101,11 @@ def create_shard(
db_offset = 0
db_file = open(output_path, "wb")
for files_chunk in tqdm(
chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False
chunk_iter,
total=ceil(len(shard_files) / CHUNK_SIZE),
desc=pbar_desc,
position=shard_num,
leave=False,
):
# get processed files for one chunk
chunk_data = process_chunk(files_chunk)
......@@ -158,7 +162,7 @@ def main(args):
index_path = output_dir / f"{output_db_name}.index"
with open(index_path, "w") as fp:
json.dump(super_index, fp, indent=4)
print("Done.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment