# create_alignment_db_sharded.py
"""
This is a modified version of the create_alignment_db.py script in OpenFold
which supports sharding into multiple files. The created index is already a
super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""
8

9
import argparse
import json
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from math import ceil
from multiprocessing import cpu_count
from pathlib import Path

from tqdm import tqdm


20
def split_file_list(file_list: list[Path], n_shards: int):
    """
    Split up the total file list into n_shards sublists.

    Entries are dealt out round-robin: sublist ``i`` receives every
    ``n_shards``-th entry starting at index ``i``. Every input entry therefore
    lands in exactly one sublist and sublist lengths differ by at most one.

    Args:
        file_list: Entries to distribute.
        n_shards: Number of sublists to create; must be at least 1.

    Returns:
        A list containing ``n_shards`` lists. Some of them are empty when
        there are fewer entries than shards.

    Raises:
        ValueError: If ``n_shards`` is smaller than 1.
    """
    # Validate explicitly instead of relying on an assert, which is stripped
    # when Python runs with -O.
    if n_shards < 1:
        raise ValueError(f"n_shards must be at least 1, but is {n_shards}")

    return [file_list[i::n_shards] for i in range(n_shards)]


34
def chunked_iterator(lst: list, chunk_size: int) -> Iterator[list]:
    """
    Iterate over a list in chunks of size chunk_size.

    The last chunk is shorter when the list length is not a multiple of
    ``chunk_size``. An empty list yields no chunks.

    Args:
        lst: List to slice into consecutive chunks.
        chunk_size: Maximum number of items per chunk; must be at least 1.

    Yields:
        Consecutive slices of ``lst``.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1. Because this is a
            generator, the error is raised when iteration starts.
    """
    # A negative step would make range() empty and silently drop all items.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be at least 1, but is {chunk_size}")

    for start in range(0, len(lst), chunk_size):
        yield lst[start : start + chunk_size]


40
def read_chain_dir(chain_dir: Path) -> dict:
    """
    Read all alignment files in a single chain directory and return a dict
    mapping chain name to file names and bytes.

    Args:
        chain_dir: Directory named ``<pdb_id>_<chain>`` whose entries are the
            alignment files of that chain.

    Returns:
        A single-entry dict ``{chain_name: [(file_name, file_bytes), ...]}``
        where ``chain_name`` is the directory name with the PDB ID part
        lowercased, and the files are ordered by sorted path.

    Raises:
        ValueError: If ``chain_dir`` is not a directory, or its name does not
            consist of exactly two parts separated by one underscore.
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    # Check the name shape explicitly so a malformed directory name is
    # reported clearly instead of as a tuple-unpacking error.
    name_parts = chain_dir.name.split("_")
    if len(name_parts) != 2:
        raise ValueError(
            "chain_dir name must have the form <pdb_id>_<chain>, "
            f"but is {chain_dir.name}"
        )

    # ensure that PDB IDs are all lowercase
    pdb_id, chain = name_parts
    chain_name = f"{pdb_id.lower()}_{chain}"

    # Sorted so that the order of files within the chain is deterministic.
    file_data = [
        (file_path.name, file_path.read_bytes())
        for file_path in sorted(chain_dir.iterdir())
    ]

    return {chain_name: file_data}


66
def process_chunk(chain_files: list[Path]) -> dict:
    """
    Returns the file names and bytes for all chains in a chunk of files.

    Args:
        chain_files: Chain directories to read; each one is passed to
            ``read_chain_dir``.

    Returns:
        The merged ``read_chain_dir`` results for all directories in the
        chunk, i.e. ``{chain_name: [(file_name, file_bytes), ...], ...}``.
        If two directories map to the same chain name, the later one in
        ``chain_files`` replaces the earlier one.
    """
    chunk_data = {}

    # The directories are read concurrently in a thread pool; executor.map
    # yields results in input order, so the merge order is deterministic.
    with ThreadPoolExecutor() as executor:
        for file_data in executor.map(read_chain_dir, chain_files):
            chunk_data.update(file_data)

    return chunk_data


def create_index_default_dict() -> dict:
    """
    Build a fresh, empty index entry: no database assigned and no files yet.
    """
    # A new list is created on every call so entries never share state.
    empty_entry = dict(db=None, files=[])
    return empty_entry


def create_shard(
    shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Creates a single shard of the alignment database, and returns the
    corresponding indices for the super index.

    All files of all chains in ``shard_files`` are concatenated into
    ``<output_dir>/<output_name>_<shard_num>.db``. The chain directories are
    read in chunks of ``CHUNK_SIZE`` and each chunk is written before the next
    one is read.

    Args:
        shard_files: Chain directories belonging to this shard.
        output_dir: Directory in which the shard's ``.db`` file is created.
        output_name: Prefix of the ``.db`` file name.
        shard_num: Number of this shard; used in the file name and as the
            position of the progress bar.

    Returns:
        The index of this shard, mapping each chain name to
        ``{"db": <db file name>, "files": [(file_name, db_offset,
        file_length), ...]}``.
    """
    CHUNK_SIZE = 200
    shard_index = defaultdict(
        create_index_default_dict
    )  # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
    chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)

    pbar_desc = f"Shard {shard_num}"
    output_path = output_dir / f"{output_name}_{shard_num}.db"

    db_offset = 0
    # The context manager closes the db file even if reading a chunk raises.
    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunk_iter,
            total=ceil(len(shard_files) / CHUNK_SIZE),
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # get processed files for one chunk
            chunk_data = process_chunk(files_chunk)

            # write to db and store info in index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name

                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    # offsets are byte positions within this shard's db file
                    db_offset += file_length

    return shard_index


def main(args):
    """
    Create the sharded alignment database and its super index.

    Reads the following attributes from ``args``:

    - ``alignment_dir``: directory whose entries are the chain directories.
    - ``output_db_path``: output directory; created if it does not exist.
    - ``output_db_name``: prefix for the ``.db`` shard files and the
      ``.index`` file.
    - ``n_shards``: number of ``.db`` files to create; must be at least 1.
    - ``duplicate_chains_file``: optional file with one whitespace-separated
      group of identical chains per line.

    Raises:
        ValueError: If ``n_shards`` is smaller than 1.
    """
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    # Fail fast, before touching the file system or scanning directories.
    if n_shards < 1:
        raise ValueError(f"n_shards must be at least 1, but is {n_shards}")

    output_dir.mkdir(exist_ok=True, parents=True)

    n_cpus = cpu_count()
    if n_shards > n_cpus:
        print(
            f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
            "This may result in slower performance. Consider using a smaller number of shards."
        )

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted(tqdm(alignment_dir.iterdir()))

    # split chain dirs into n_shards sublists
    chain_dir_shards = split_file_list(all_chain_dirs, n_shards)

    # total index for all shards
    super_index = {}

    # create a shard for each sublist
    print(f"Creating {n_shards} alignment-db files...")
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_num
            )
            for shard_num, shard_files in enumerate(chain_dir_shards)
        ]

        # merge the per-shard indices as soon as each shard finishes
        for future in as_completed(futures):
            shard_index = future.result()
            super_index.update(shard_index)
    print("\nCreated all shards.")

    if args.duplicate_chains_file:
        print("Extending super index with duplicate chains...")
        duplicates_added = 0
        with open(args.duplicate_chains_file, "r") as fp:
            # Skip blank lines; they would otherwise be reported as groups
            # without a representative.
            duplicate_chains = [line.split() for line in fp if line.strip()]

        for chains in duplicate_chains:
            # find representative with alignment
            for chain in chains:
                if chain in super_index:
                    representative_chain = chain
                    break
            else:
                print(f"No representative chain found for {chains}, skipping...")
                continue

            # add duplicates to index; they share the representative's entry
            for chain in chains:
                if chain != representative_chain:
                    super_index[chain] = super_index[representative_chain]
                    duplicates_added += 1

        print(f"Added {duplicates_added} duplicate chains to index.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)

    print("Done.")


# Command-line entry point: parse the arguments and build the database.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        This script creates an alignment database format from a directory of
        precomputed alignments. For better file system health, the total
        database is split into n_shards files, where each shard contains a
        subset of the total alignments. The output is a directory containing the
        n_shards database files, and a single index file mapping chain names to
        the database file and byte offsets for each alignment file.

        Note: For optimal performance, your machine should have at least as many
        cores as shards you want to create.
        """
    )
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to precomputed flattened alignment directory, with one
                subdirectory per chain.""",
    )
    parser.add_argument(
        "output_db_path",
        type=Path,
        help="""Directory in which the database files and the index file are
                written. It is created if it does not exist.""",
    )
    parser.add_argument(
        "output_db_name",
        type=str,
        help="""Name prefix of the output files: shards are written as
                <output_db_name>_<shard number>.db and the index as
                <output_db_name>.index.""",
    )
    parser.add_argument(
        "--n_shards",
        type=int,
        help="Number of shards to split the database into",
        default=10,
    )
    parser.add_argument(
        "--duplicate_chains_file",
        type=Path,
        help="""
        Optional path to file containing duplicate chain information, where each
        line contains chains that are 100% sequence identical. If provided,
        duplicate chains will be added to the index and point to the same
        underlying database entry as their representatives in the alignment dir.
        """,
        default=None,
    )

    args = parser.parse_args()

    main(args)