utils.py 2.62 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from absl import logging
import json
import os
from typing import Mapping, Sequence

from unifold.data import protein


def get_chain_id_map(
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Makes a mapping from PDB-format chain ID to sequence and descriptions,
    and parses the order of multi-chains.

    Identical sequences share a single chain ID. IDs are drawn from
    ``protein.PDB_CHAIN_IDS`` in the order in which each distinct sequence
    first appears in ``sequences``.

    Args:
        sequences: Chain sequences, one entry per chain.
        descriptions: Chain descriptions, paired positionally with
            ``sequences``. The two are zipped, so surplus entries in the
            longer one are ignored when descriptions and order are built.

    Returns:
        A tuple ``(chain_id_map, chain_order)``:
            chain_id_map: dict from chain ID to
                ``{"descriptions": [...], "sequence": <sequence>}``, with
                one entry per distinct sequence.
            chain_order: list holding the chain ID of every
                (sequence, description) pair, in input order.
    """
    # First-seen position of every distinct sequence. A dict keeps insertion
    # order and gives constant-time lookups, avoiding repeated linear scans
    # (each comparing whole sequence strings) over a list.
    seq_index = {}
    for seq in sequences:
        seq_index.setdefault(seq, len(seq_index))

    chain_id_map = {
        chain_id: {"descriptions": [], "sequence": seq}
        for chain_id, seq in zip(protein.PDB_CHAIN_IDS, seq_index)
    }

    chain_order = []
    for seq, des in zip(sequences, descriptions):
        chain_id = protein.PDB_CHAIN_IDS[seq_index[seq]]
        chain_id_map[chain_id]["descriptions"].append(des)
        chain_order.append(chain_id)

    return chain_id_map, chain_order


def divide_multi_chains(
    fasta_name: str,
    output_dir_base: str,
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Divides the multi-chains fasta into several single fasta files and
    records multi-chains mapping information.

    The following files are written under
    ``<output_dir_base>/<fasta_name>/`` (the directory is created if needed):
        chain_id_map.json: the chain ID mapping returned by
            ``get_chain_id_map``, dumped with sorted keys.
        chains.txt: the chain IDs of all input chains, space-separated,
            in input order.
        <fasta_name>_<chain_id>.fasta: one file per entry of the chain ID
            mapping, with header ``>chain_<chain_id>`` followed by the
            sequence.

    Args:
        fasta_name: Name of the multi-chain target; used as the output
            sub-directory name and as the prefix of the per-chain files.
        output_dir_base: Directory under which the target's sub-directory
            is created.
        sequences: Chain sequences, one entry per chain.
        descriptions: Chain descriptions, one per entry of ``sequences``.

    Returns:
        A tuple ``(temp_names, temp_paths)`` of parallel lists: the
        per-chain names (``<fasta_name>_<chain_id>``) and the paths of the
        corresponding fasta files.

    Raises:
        ValueError: If ``sequences`` and ``descriptions`` differ in length,
            or if there are more than ``protein.PDB_MAX_CHAINS`` sequences.
    """
    if len(sequences) != len(descriptions):
        raise ValueError(
            "sequences and descriptions must have equal length. "
            f"Got {len(sequences)} != {len(descriptions)}."
        )
    if len(sequences) > protein.PDB_MAX_CHAINS:
        raise ValueError(
            "Cannot process more chains than the PDB format supports. "
            f"Got {len(sequences)} chains."
        )

    chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)

    output_dir = os.path.join(output_dir_base, fasta_name)
    # exist_ok avoids the check-then-create race: another job creating the
    # same directory between an existence test and makedirs would otherwise
    # make this call fail with FileExistsError.
    os.makedirs(output_dir, exist_ok=True)

    chain_id_map_path = os.path.join(output_dir, "chain_id_map.json")
    with open(chain_id_map_path, "w") as f:
        json.dump(chain_id_map, f, indent=4, sort_keys=True)

    chain_order_str = " ".join(chain_order)
    chain_order_path = os.path.join(output_dir, "chains.txt")
    with open(chain_order_path, "w") as f:
        f.write(chain_order_str)

    logging.info(
        "Mapping multi-chains fasta with chain order: %s", chain_order_str
    )

    temp_names = []
    temp_paths = []
    for chain_id, chain_info in chain_id_map.items():
        temp_name = f"{fasta_name}_{chain_id}"
        temp_path = os.path.join(output_dir, temp_name + ".fasta")
        seq = chain_info["sequence"]
        # Header plus sequence; no trailing newline is written.
        with open(temp_path, "w") as f:
            f.write(f">chain_{chain_id}\n{seq}")
        temp_names.append(temp_name)
        temp_paths.append(temp_path)
    return temp_names, temp_paths