create_alignment_db_sharded.py 5.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
This is a modified version of the create_alignment_db.py script in OpenFold
which supports sharding into multiple files. The created index is already a
super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""
import argparse
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
from math import ceil


def split_file_list(file_list, n_shards):
    """
    Split up the total file list into n_shards sublists.

    Elements are distributed round-robin: sublist ``i`` holds
    ``file_list[i], file_list[i + n_shards], ...``. Together the sublists
    contain every element of ``file_list`` exactly once; sublists whose
    index is >= ``len(file_list)`` are empty.

    Args:
        file_list: Sequence supporting extended slicing (e.g. a list).
        n_shards: Number of sublists to produce; must be >= 1.

    Returns:
        A list of ``n_shards`` sublists.

    Raises:
        ValueError: If ``n_shards`` is smaller than 1.
    """
    if n_shards < 1:
        raise ValueError(f"n_shards must be a positive integer, but is {n_shards}")

    # Strided slicing with start offsets 0..n_shards-1 partitions the input by
    # construction, so no post-hoc completeness check is required.
    return [file_list[i::n_shards] for i in range(n_shards)]


def chunked_iterator(lst, chunk_size):
    """
    Lazily yield successive slices of ``lst`` that are ``chunk_size`` long.

    The last slice is shorter when ``len(lst)`` is not a multiple of
    ``chunk_size``. Each slice has the same type as ``lst``.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[begin : begin + chunk_size] for begin in starts)


def read_chain_dir(chain_dir) -> dict:
    """
    Read all alignment files in a single chain directory.

    The directory name is expected to be ``<pdb_id>_<chain>``. The PDB ID
    part is lower-cased to form the chain name used as the index key.

    Args:
        chain_dir: ``pathlib.Path`` of the chain directory.

    Returns:
        ``{chain_name: [(file_name, file_bytes), ...]}``, with the files
        in sorted path order.

    Raises:
        ValueError: If ``chain_dir`` is not a directory, or if its name does
            not consist of exactly two ``_``-separated parts.
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    name_parts = chain_dir.name.split("_")
    if len(name_parts) != 2:
        # Same exception type as the bare tuple-unpack this replaces. Naming
        # the offending directory matters because this runs inside worker
        # threads/processes, where an anonymous unpack error is hard to trace.
        raise ValueError(
            "chain_dir name must have the form '<pdb_id>_<chain>', "
            f"but is {chain_dir.name!r} ({chain_dir})"
        )
    pdb_id, chain = name_parts

    # ensure that PDB IDs are all lowercase
    chain_name = f"{pdb_id.lower()}_{chain}"

    # Sorted so the byte layout of the shard file is deterministic.
    file_data = [
        (file_path.name, file_path.read_bytes())
        for file_path in sorted(chain_dir.iterdir())
    ]

    return {chain_name: file_data}


def process_chunk(chain_files: List[Path]) -> dict:
    """
    Read every chain directory in ``chain_files`` using a thread pool.

    The per-directory results from ``read_chain_dir`` are merged, in input
    order, into one ``{chain_name: [(file_name, file_bytes), ...]}`` dict.
    """
    with ThreadPoolExecutor() as pool:
        per_dir_results = list(pool.map(read_chain_dir, chain_files))

    return {
        chain_name: files
        for result in per_dir_results
        for chain_name, files in result.items()
    }


def create_index_default_dict() -> dict:
    """
    Build a fresh, empty index entry for one chain.

    Serves as the ``defaultdict`` factory in ``create_shard``. Because it is
    a module-level function (not a lambda), that ``defaultdict`` stays
    picklable when it is returned from a ProcessPoolExecutor worker.
    """
    entry = dict(db=None, files=[])
    return entry


def create_shard(
    shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Create a single shard of the alignment database.

    The bytes of every alignment file of every chain directory in
    ``shard_files`` are concatenated into
    ``<output_dir>/<output_name>_<shard_num>.db``. Their locations are
    recorded for the super index.

    Args:
        shard_files: Chain directories assigned to this shard.
        output_dir: Directory the ``.db`` file is written to.
        output_name: Base name of the database files.
        shard_num: Index of this shard. It is used in the file name and as
            the tqdm progress-bar position.

    Returns:
        ``{chain_name: {"db": <db file name>, "files": [(file_name,
        db_offset, file_length), ...]}, ...}``. Offsets are byte offsets
        into this shard's ``.db`` file.
    """
    CHUNK_SIZE = 200
    # Module-level factory (not a lambda) keeps the result picklable when it
    # is sent back from a worker process.
    shard_index = defaultdict(create_index_default_dict)
    n_chunks = ceil(len(shard_files) / CHUNK_SIZE)

    pbar_desc = f"Shard {shard_num}"
    output_path = output_dir / f"{output_name}_{shard_num}.db"

    db_offset = 0
    # The context manager guarantees the handle is closed (and buffered bytes
    # flushed) even if reading a chunk raises.
    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunked_iterator(shard_files, CHUNK_SIZE),
            total=n_chunks,
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # Chain directories are read CHUNK_SIZE at a time and written out
            # before the next chunk is read (presumably to bound memory use).
            chunk_data = process_chunk(files_chunk)

            # write to db and store info in index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name

                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    db_offset += file_length

    return shard_index


def main(args):
    """
    Build the sharded alignment database and its super index.

    Args:
        args: Parsed command-line arguments providing ``alignment_dir``
            (Path), ``output_db_path`` (Path, output directory),
            ``output_db_name`` (str) and ``n_shards`` (int).

    Writes ``<output_db_path>/<output_db_name>_<i>.db`` for each shard ``i``
    and the JSON index ``<output_db_path>/<output_db_name>.index``.
    """
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    # Create the output directory up front. Otherwise every worker fails
    # when it opens its shard file inside a missing directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted(tqdm(alignment_dir.iterdir()))

    # split chain dirs into n_shards sublists
    chain_dir_shards = split_file_list(all_chain_dirs, n_shards)

    # total index for all shards: {chain_name: {"db": ..., "files": [...]}}
    super_index = {}

    # create a shard for each sublist
    print(f"Creating {n_shards} alignment-db files...")
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_num
            )
            for shard_num, shard_files in enumerate(chain_dir_shards)
        ]

        # Each chain directory is assigned to exactly one shard, so (barring
        # chain-name collisions after lower-casing) the merge order of
        # completed shards does not matter.
        for future in as_completed(futures):
            super_index.update(future.result())
    print("\nCreated all shards.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)

    print("Done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        This script creates an alignment database format from a directory of
        precomputed alignments. For better file system health, the total
        database is split into n_shards files, where each shard contains a
        subset of the total alignments. The output is a directory containing the
        n_shards database files, and a single index file mapping chain names to
        the database file and byte offsets for each alignment file.
        
        Note: For optimal performance, your machine should have at least as many
        cores as shards you want to create.
        """
    )
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to precomputed alignment directory, with one subdirectory 
                per chain.""",
    )
    parser.add_argument("output_db_path", type=Path)
    parser.add_argument("output_db_name", type=str)
    parser.add_argument(
        "n_shards", type=int, help="Number of shards to split the database into"
    )

    args = parser.parse_args()

    main(args)