classification_utils.py 1.25 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sys
from typing import List, Sequence

import numpy as np


def levenshtein_distance(seq1: Sequence, seq2: Sequence) -> int:
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn ``seq1`` into ``seq2``.
    Elements are compared with ``==`` and accessed by integer index, so
    strings, lists, tuples and 1-D arrays all work.

    Args:
        seq1: Source sequence.
        seq2: Target sequence.

    Returns:
        The edit distance as a non-negative ``int``.
    """
    # Trim the common prefix and suffix: matching ends never contribute to
    # the distance. This also gives identical inputs an O(n) fast path
    # without evaluating ``seq1 == seq2``, which is element-wise (and cannot
    # be used in an ``if``) for NumPy arrays.
    start = 0
    end1, end2 = len(seq1), len(seq2)
    while start < end1 and start < end2 and seq1[start] == seq2[start]:
        start += 1
    # Stop at ``start`` so the suffix never overlaps the trimmed prefix.
    while end1 > start and end2 > start and seq1[end1 - 1] == seq2[end2 - 1]:
        end1 -= 1
        end2 -= 1

    len1 = end1 - start
    len2 = end2 - start
    if len1 == 0:
        return len2
    if len2 == 0:
        return len1

    # Wagner-Fischer DP keeping only the previous row: O(len2) memory and
    # plain Python ints (so the result is an int, not a float).
    previous = list(range(len2 + 1))
    for i in range(1, len1 + 1):
        item1 = seq1[start + i - 1]
        current = [i] + [0] * len2
        for j in range(1, len2 + 1):
            if item1 == seq2[start + j - 1]:
                current[j] = previous[j - 1]
            else:
                current[j] = 1 + min(
                    previous[j - 1],  # substitution
                    previous[j],  # deletion from seq1
                    current[j - 1],  # insertion into seq1
                )
        previous = current
    return previous[len2]


def get_closest_label(pred: Sequence, classes: List[Sequence]) -> int:
    """Return the index of the class label closest to ``pred``.

    Closeness is measured with ``levenshtein_distance``. When several
    labels share the smallest distance, the earliest one in ``classes``
    wins.

    Args:
        pred: The predicted sequence to match.
        classes: Candidate class labels, in index order.

    Returns:
        Index into ``classes`` of the label with the smallest edit distance.

    Raises:
        ValueError: If ``classes`` is empty. Previously ``sys.maxsize`` was
            returned, which is not a valid index into ``classes``.
    """
    best_index = None
    best_distance = None
    for index, label in enumerate(classes):
        distance = levenshtein_distance(pred, label)
        if best_distance is None or distance < best_distance:
            best_index, best_distance = index, distance
            if distance == 0:
                # An exact match cannot be beaten, and stopping here keeps
                # the "first label wins on ties" rule.
                break
    if best_index is None:
        raise ValueError("get_closest_label() requires at least one class label")
    return best_index


# Public API of this module; also the names exported by `from <module> import *`.
__all__ = ["levenshtein_distance", "get_closest_label"]