import argparse
from functools import partial
import json
import logging
from multiprocessing import Pool
import os
import string

import sys
sys.path.append(".") # an innocent hack to get this to run from the top level

from collections import defaultdict

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse 
from openfold.np import protein, residue_constants


def parse_file(
    f, 
    args,
    chain_cluster_size_dict
):
    """Parse one mmCIF or PDB file into a per-chain metadata dict.

    Args:
        f: Filename (relative to args.data_dir) ending in .cif or .pdb.
        args: Parsed CLI namespace; only args.data_dir is read here.
        chain_cluster_size_dict: Optional mapping from upper-cased
            "{FILE_ID}_{CHAIN_ID}" to that chain's cluster size, or None.

    Returns:
        Dict mapping "{file_id}_{chain_id}" to a dict with keys "seq" and
        "resolution" (plus "release_date" for mmCIF input), and
        "cluster_size" when a cluster dict is supplied (-1 for chains
        absent from it). Empty dict for unparseable or unrecognized files.
    """
    file_id, ext = os.path.splitext(f)

    # BUG FIX: `out` was previously assigned only inside the .cif/.pdb
    # branches, so an unrecognized extension raised NameError at `return`.
    # Initializing here makes unknown extensions return {} instead.
    out = {}

    def _record_cluster_size(local_data, full_name):
        # Cluster membership is keyed on the upper-cased chain name;
        # -1 marks chains that do not appear in the cluster file.
        if chain_cluster_size_dict is not None:
            local_data["cluster_size"] = chain_cluster_size_dict.get(
                full_name.upper(), -1
            )

    if ext == ".cif":
        with open(os.path.join(args.data_dir, f), "r") as fp:
            mmcif_string = fp.read()
        mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
        if mmcif.mmcif_object is None:
            # Log and skip unparseable files rather than aborting the run.
            logging.info(f"Could not parse {f}. Skipping...")
            return {}
        mmcif = mmcif.mmcif_object

        for chain_id, seq in mmcif.chain_to_seqres.items():
            full_name = "_".join([file_id, chain_id])
            local_data = out.setdefault(full_name, {})
            local_data["release_date"] = mmcif.header["release_date"]
            local_data["seq"] = seq
            local_data["resolution"] = mmcif.header["resolution"]
            _record_cluster_size(local_data, full_name)
    elif ext == ".pdb":
        with open(os.path.join(args.data_dir, f), "r") as fp:
            pdb_string = fp.read()

        protein_object = protein.from_pdb_string(pdb_string, None)
        aatype = protein_object.aatype
        chain_index = protein_object.chain_index

        # Group one-letter residue codes by integer chain index.
        chain_dict = defaultdict(list)
        for i in range(aatype.shape[0]):
            chain_dict[chain_index[i]].append(
                residue_constants.restypes_with_x[aatype[i]]
            )

        # Map integer chain indices back onto letter tags (A, B, C, ...),
        # matching PDB chain-naming convention.
        chain_tags = string.ascii_uppercase
        for chain, seq in chain_dict.items():
            full_name = "_".join([file_id, chain_tags[chain]])
            local_data = out.setdefault(full_name, {})
            # PDB input carries no resolution metadata in this path; use 0.
            local_data["resolution"] = 0.
            local_data["seq"] = ''.join(seq)
            _record_cluster_size(local_data, full_name)

    return out


def main(args):
    """Build the chain data cache from args.data_dir and write it as JSON."""
    # Optionally load cluster sizes. Each line of the cluster file lists
    # the members of one cluster as whitespace-separated chain IDs; every
    # member maps to the size of its cluster.
    chain_cluster_size_dict = None
    if args.cluster_file is not None:
        chain_cluster_size_dict = {}
        with open(args.cluster_file, "r") as fp:
            cluster_lines = [line.strip() for line in fp.readlines()]

        for line in cluster_lines:
            members = line.split()
            for member in members:
                chain_cluster_size_dict[member.upper()] = len(members)

    # Only mmCIF and PDB files are parseable; ignore everything else.
    accepted_exts = [".cif", ".pdb"]
    candidates = list(os.listdir(args.data_dir))
    files = [f for f in candidates if os.path.splitext(f)[-1] in accepted_exts]

    fn = partial(
        parse_file,
        args=args,
        chain_cluster_size_dict=chain_cluster_size_dict,
    )

    # Parse files in parallel; imap_unordered lets results arrive in any
    # order, which is fine since each file contributes disjoint keys.
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data, indent=4))


if __name__ == "__main__":
    # Command-line entry point for building the chain data cache.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "data_dir", type=str, help="Directory containing mmCIF or PDB files"
    )
    parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    cluster_file_help = (
        "Path to a cluster file (e.g. PDB40), one cluster "
        "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
        "Chains not in this cluster file will NOT be filtered by cluster "
        "size."
    )
    parser.add_argument(
        "--cluster_file", type=str, default=None, help=cluster_file_help
    )
    parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing",
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time",
    )

    main(parser.parse_args())