kalign.py 3.71 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence

from absl import logging

from . import utils


def _to_a3m(sequences: Sequence[str]) -> str:
    """Converts sequences to an a3m file."""
    names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
    a3m = []
    for sequence, name in zip(sequences, names):
        a3m.append(">" + name + "\n")
        a3m.append(sequence + "\n")
    return "".join(a3m)


class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary. The path is stored as-is
                and is not validated here; a missing or non-executable binary
                only surfaces when `align` launches the subprocess.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have to be at
                least 6 residues long (Kalign requires this). Note that the order in
                which you give the sequences might alter the output slightly as
                different alignment tree might get constructed.

        Returns:
            A string with the alignment in a3m format (the contents of the file
            Kalign writes with `-format fasta`).

        Raises:
            RuntimeError: If Kalign exits with a non-zero return code.
            ValueError: If any of the sequences is less than 6 residues long.
            OSError: If the Kalign binary cannot be launched (raised by
                `subprocess.Popen`, e.g. FileNotFoundError for a bad path).
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Validate before touching the filesystem or spawning a process.
        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (s, len(s))
                )

        with utils.tmpdir_manager() as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            with open(input_fasta_path, "w") as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-o",
                output_a3m_path,
                "-format",
                "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                # communicate() waits for the process to exit, so returncode
                # is populated afterwards; a separate wait() is unnecessary.
                stdout_bytes, stderr_bytes = process.communicate()
            retcode = process.returncode

            # Decode once and leniently: a stray non-UTF-8 byte in Kalign's
            # output must not raise UnicodeDecodeError and hide the real
            # return code / error message below.
            stdout = stdout_bytes.decode("utf-8", errors="replace")
            stderr = stderr_bytes.decode("utf-8", errors="replace")
            logging.info("Kalign stdout:\n%s\n\nstderr:\n%s\n", stdout, stderr)

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n" % (stdout, stderr)
                )

            # Read the result while the temporary directory still exists.
            with open(output_a3m_path) as f:
                a3m = f.read()

            return a3m