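# generate_chain_data_cache.py
#
# Builds a JSON cache of per-chain metadata (sequence, resolution, release
# date for mmCIF input, and an optional cluster size) from a directory of
# mmCIF / PDB files.
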
import argparse
from functools import partial
import json
import logging
from multiprocessing import Pool
import os

import sys
sys.path.append(".") # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse 
from openfold.np import protein, residue_constants


def _parse_mmcif_file(f, file_id, args, chain_cluster_size_dict):
    """Return one metadata entry per chain of the mmCIF file *f*.

    Entries are keyed ``"{file_id}_{chain_id}"``. Returns ``{}`` (after
    logging at INFO level) when the file cannot be parsed.
    """
    with open(os.path.join(args.data_dir, f), "r") as fp:
        mmcif_string = fp.read()

    parse_result = parse(file_id=file_id, mmcif_string=mmcif_string)
    mmcif = parse_result.mmcif_object
    if mmcif is None:
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    out = {}
    for chain_id, seq in mmcif.chain_to_seqres.items():
        full_name = "_".join([file_id, chain_id])
        chain_data = {
            "release_date": mmcif.header["release_date"],
            "seq": seq,
            "resolution": mmcif.header["resolution"],
        }
        if chain_cluster_size_dict is not None:
            # main() stores cluster keys upper-cased, so look up the same way;
            # -1 marks chains absent from the cluster file.
            chain_data["cluster_size"] = chain_cluster_size_dict.get(
                full_name.upper(), -1
            )
        out[full_name] = chain_data
    return out


def _parse_pdb_file(f, file_id, args, chain_cluster_size_dict):
    """Return a single metadata entry, keyed by *file_id*, for the PDB file *f*."""
    with open(os.path.join(args.data_dir, f), "r") as fp:
        pdb_string = fp.read()

    protein_object = protein.from_pdb_string(pdb_string, None)

    chain_data = {
        "seq": residue_constants.aatype_to_str_sequence(
            protein_object.aatype,
        ),
        # No resolution is extracted from PDB input; 0. is a placeholder.
        "resolution": 0.,
    }
    if chain_cluster_size_dict is not None:
        # The whole file becomes one entry keyed by file_id, so the cluster
        # size is looked up by file_id as well (presumably PDB files are
        # named "{PROT_ID}_{CHAIN_ID}" to match the cluster file; confirm).
        chain_data["cluster_size"] = chain_cluster_size_dict.get(
            file_id.upper(), -1
        )
    return {file_id: chain_data}


def parse_file(
    f,
    args,
    chain_cluster_size_dict
):
    """Extract per-chain metadata from a single structure file.

    Args:
        f: File name, relative to ``args.data_dir``. Only ``.cif`` and
            ``.pdb`` extensions are handled.
        args: Parsed command-line arguments; only ``args.data_dir`` is read,
            and only for supported extensions.
        chain_cluster_size_dict: Optional mapping from upper-cased chain name
            to cluster size. When not ``None``, every entry gets a
            ``"cluster_size"`` field (``-1`` when the name is not in the map).

    Returns:
        A dict mapping entry names to metadata dicts. mmCIF files yield one
        entry per chain (``"{file_id}_{chain_id}"``) with ``release_date``,
        ``seq`` and ``resolution``. PDB files yield a single entry keyed by
        ``file_id`` with ``seq`` and ``resolution``. Unparseable mmCIF files
        and unsupported extensions yield ``{}``.
    """
    file_id, ext = os.path.splitext(f)
    if ext == ".cif":
        return _parse_mmcif_file(f, file_id, args, chain_cluster_size_dict)
    if ext == ".pdb":
        return _parse_pdb_file(f, file_id, args, chain_cluster_size_dict)
    # Unsupported extensions contribute nothing to the cache.
    return {}


def _load_cluster_sizes(cluster_file):
    """Map every upper-cased chain name in *cluster_file* to its cluster size.

    Each line lists the whitespace-separated members of one cluster; blank
    lines are ignored. A chain listed in several clusters keeps the size of
    the last one.
    """
    chain_cluster_size_dict = {}
    with open(cluster_file, "r") as fp:
        for line in fp:
            chain_ids = line.split()
            cluster_len = len(chain_ids)
            for chain_id in chain_ids:
                chain_cluster_size_dict[chain_id.upper()] = cluster_len
    return chain_cluster_size_dict


def main(args):
    """Build the chain data cache and write it to ``args.output_path``.

    Every ``.cif`` / ``.pdb`` file directly inside ``args.data_dir`` is parsed
    with ``parse_file`` in a pool of ``args.no_workers`` processes. The
    per-file results are merged into one dict and written as indented JSON.
    When ``args.cluster_file`` is given, its cluster sizes are attached to
    each entry.
    """
    chain_cluster_size_dict = None
    if args.cluster_file is not None:
        chain_cluster_size_dict = _load_cluster_sizes(args.cluster_file)

    accepted_exts = {".cif", ".pdb"}
    files = [
        f for f in os.listdir(args.data_dir)
        if os.path.splitext(f)[-1] in accepted_exts
    ]
    fn = partial(
        parse_file,
        args=args,
        chain_cluster_size_dict=chain_cluster_size_dict,
    )

    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # Completion order is irrelevant: results are merged into a dict.
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        json.dump(data, fp, indent=4)


if __name__ == "__main__":
    # Command-line entry point: parse arguments and build the cache.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "data_dir", type=str, help="Directory containing mmCIF or PDB files"
    )
    parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    parser.add_argument(
        "--cluster_file", type=str, default=None,
        help=(
            "Path to a cluster file (e.g. PDB40), one cluster "
            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
            "Chains not in this cluster file will NOT be filtered by cluster "
            "size."
        )
    )
    parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )

    args = parser.parse_args()

    main(args)