aligners.py 1.93 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Type

from sequence_align.pairwise import hirschberg, needleman_wunsch

from .registry import BaseRegistry


class AlignerRegistry(BaseRegistry[Type["BaseAligner"]]):
    """Registry of aligner classes.

    The registry is parameterized with ``Type["BaseAligner"]``, i.e. its
    entries are aligner classes rather than aligner instances. Aligners in
    this module register themselves under a string name with the
    ``@AlignerRegistry.add(name)`` class decorator; the registry machinery
    itself is inherited from ``BaseRegistry``.
    """


class BaseAligner:
    """Common parent of all aligners.

    Subclasses are expected to override :meth:`align`; the implementation
    here only raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored here;
        # nothing is forwarded up the MRO.
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        """Align ``gold`` against ``pred``.

        Args:
            gold: Reference sequence of string tokens.
            pred: Predicted sequence of string tokens.

        Returns:
            A pair of string lists (per the annotation); concrete subclasses
            define the exact contents.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@AlignerRegistry.add("hirschberg")
class HirschbergAligner(BaseAligner):
    """Aligner that delegates to ``sequence_align.pairwise.hirschberg``.

    Registered in :class:`AlignerRegistry` under the name ``"hirschberg"``.
    The constructor only stores the scoring options; they are handed to the
    library call unchanged each time :meth:`align` runs.
    """

    def __init__(
        self,
        match_score: float = 1.0,
        mismatch_score: float = -1.0,
        indel_score: float = -1.0,
        gap_token: str = "▓",
    ):
        """Store the alignment options.

        Args:
            match_score: Forwarded to ``hirschberg`` as ``match_score``.
            mismatch_score: Forwarded to ``hirschberg`` as ``mismatch_score``.
            indel_score: Forwarded to ``hirschberg`` as ``indel_score``.
            gap_token: Forwarded to ``hirschberg`` as ``gap``; presumably the
                placeholder emitted where one sequence has no counterpart
                (confirm against the sequence_align documentation).
        """
        self.gap_token = gap_token
        self.indel_score = indel_score
        self.mismatch_score = mismatch_score
        self.match_score = match_score
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        """Align ``gold`` with ``pred`` using the stored options.

        Args:
            gold: Reference sequence of string tokens.
            pred: Predicted sequence of string tokens.

        Returns:
            Whatever ``hirschberg`` returns for these inputs, annotated here
            as a pair of string lists.
        """
        # Collect the per-instance configuration under the keyword names the
        # library function expects (note: ``gap_token`` maps to ``gap``).
        options = {
            "match_score": self.match_score,
            "mismatch_score": self.mismatch_score,
            "indel_score": self.indel_score,
            "gap": self.gap_token,
        }
        return hirschberg(gold, pred, **options)


@AlignerRegistry.add("needleman-wunsch")
class NeedlemanWunschAligner(BaseAligner):
    """Aligner that delegates to ``sequence_align.pairwise.needleman_wunsch``.

    Registered in :class:`AlignerRegistry` under the name
    ``"needleman-wunsch"``. The constructor only stores the scoring options;
    they are handed to the library call unchanged each time :meth:`align`
    runs.
    """

    def __init__(
        self,
        match_score: float = 1.0,
        mismatch_score: float = -1.0,
        indel_score: float = -1.0,
        gap_token: str = "▓",
    ):
        """Store the alignment options.

        Args:
            match_score: Forwarded to ``needleman_wunsch`` as ``match_score``.
            mismatch_score: Forwarded to ``needleman_wunsch`` as
                ``mismatch_score``.
            indel_score: Forwarded to ``needleman_wunsch`` as ``indel_score``.
            gap_token: Forwarded to ``needleman_wunsch`` as ``gap``;
                presumably the placeholder emitted where one sequence has no
                counterpart (confirm against the sequence_align
                documentation).
        """
        self.gap_token = gap_token
        self.indel_score = indel_score
        self.mismatch_score = mismatch_score
        self.match_score = match_score
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        """Align ``gold`` with ``pred`` using the stored options.

        Args:
            gold: Reference sequence of string tokens.
            pred: Predicted sequence of string tokens.

        Returns:
            Whatever ``needleman_wunsch`` returns for these inputs, annotated
            here as a pair of string lists.
        """
        # Collect the per-instance configuration under the keyword names the
        # library function expects (note: ``gap_token`` maps to ``gap``).
        options = {
            "match_score": self.match_score,
            "mismatch_score": self.mismatch_score,
            "indel_score": self.indel_score,
            "gap": self.gap_token,
        }
        return needleman_wunsch(gold, pred, **options)