import argparse
from functools import partial
import json
import logging
from multiprocessing import Pool
import os

import sys
sys.path.append(".") # an innocent hack to get this to run from the top level

from tqdm import tqdm

from openfold.data.mmcif_parsing import parse 
from openfold.np import protein, residue_constants


def parse_file(
    f, 
    args,
    chain_cluster_size_dict
):
    """Parse one mmCIF or PDB file into chain-data-cache entries.

    Args:
        f: Filename (relative to args.data_dir) of a .cif or .pdb file.
        args: Parsed CLI namespace; only args.data_dir is read here.
        chain_cluster_size_dict: Maps upper-cased chain names
            (e.g. "1ABC_A") to the size of their sequence cluster.

    Returns:
        Dict mapping "<file_id>_<chain_id>" (mmCIF) or "<file_id>" (PDB)
        to a dict with "seq", "resolution", "cluster_size" and, for mmCIF
        entries, "release_date". Returns {} when the file cannot be parsed
        or no chain has a known cluster.
    """
    out = {}
    file_id, ext = os.path.splitext(f)
    if(ext == ".cif"):
        with open(os.path.join(args.data_dir, f), "r") as fp:
            mmcif_string = fp.read()
        mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
        if mmcif.mmcif_object is None:
            logging.info(f"Could not parse {f}. Skipping...")
            return {}
        else:
            mmcif = mmcif.mmcif_object

        for chain_id, seq in mmcif.chain_to_seqres.items():
            full_name = "_".join([file_id, chain_id])
            out[full_name] = {}
            local_data = out[full_name]
            local_data["release_date"] = mmcif.header["release_date"]
            local_data["seq"] = seq
            local_data["resolution"] = mmcif.header["resolution"]

            cluster_size = chain_cluster_size_dict.get(full_name.upper(), None)
            if(cluster_size is None):
                # Chain has no cluster assignment: report it and drop the
                # partially-built entry so it never reaches the cache.
                print(file_id)
                out.pop(full_name)
                continue
            else:
                local_data["cluster_size"] = cluster_size
    elif(ext == ".pdb"):
        with open(args.data_dir + "/" + f if False else os.path.join(args.data_dir, f), "r") as fp:
            pdb_string = fp.read()

        protein_object = protein.from_pdb_string(pdb_string, None)

        chain_dict = {} 
        chain_dict["seq"] = residue_constants.aatype_to_str_sequence(
            protein_object.aatype,
        )
        # BUG FIX: the original wrote to `local_data`, which is bound only
        # in the .cif branch above, so every .pdb file raised NameError.
        # PDB files carry no resolution header; use the 0. sentinel.
        chain_dict["resolution"] = 0.

        cluster_size = chain_cluster_size_dict.get(file_id.upper(), None)
        if(cluster_size is None):
            print(file_id)
            return {}
        else:
            chain_dict["cluster_size"] = cluster_size

        out = {file_id: chain_dict}

    # Unknown extensions (pre-filtered by main()) fall through to {}.
    return out


def main(args):
    """Build a JSON cache of per-chain metadata for every structure file
    in args.data_dir, parsing files in parallel across worker processes."""
    # Map upper-cased chain names to the size of their cluster, if a
    # cluster file was supplied.
    chain_cluster_size_dict = {}
    if(args.cluster_file is not None):
        with open(args.cluster_file, "r") as fp:
            cluster_lines = [line.strip() for line in fp.readlines()]

        for cluster_line in cluster_lines:
            members = cluster_line.split()
            member_count = len(members)
            for member in members:
                chain_cluster_size_dict[member.upper()] = member_count

    accepted_exts = [".cif", ".pdb"]
    files = [
        f for f in os.listdir(args.data_dir)
        if os.path.splitext(f)[-1] in accepted_exts
    ]
    worker_fn = partial(
        parse_file, 
        args=args,
        chain_cluster_size_dict=chain_cluster_size_dict,
    )
    data = {}
    # Fan the files out to a process pool; merge results as they arrive.
    with Pool(processes=args.no_workers) as pool:
        with tqdm(total=len(files)) as pbar:
            for parsed in pool.imap_unordered(worker_fn, files, chunksize=args.chunksize):
                data.update(parsed)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data, indent=4))


if __name__ == "__main__":
    # CLI entry point: two positional paths plus tuning options.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "data_dir", type=str, help="Directory containing mmCIF or PDB files"
    )
    arg_parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    arg_parser.add_argument(
        "--cluster_file", type=str, default=None,
        help="Path to a cluster file (e.g. PDB40), one cluster per line"
    )
    arg_parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    arg_parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )

    main(arg_parser.parse_args())