ctc_util.py 3.1 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

def insert_blank(label, blank_id=0):
    """Interleave ``blank_id`` around every token of ``label``.

    Produces ``[blank, l0, blank, l1, ..., blank, l(n-1), blank]``: a blank
    before every token plus one trailing blank, i.e. ``2 * n + 1`` items.
    An empty ``label`` yields ``[blank_id]``.

    Args:
        label: 1-d sequence of token ids (anything ``np.expand_dims``
            accepts, e.g. a list or a 1-d array).
        blank_id: id of the blank token.

    Returns:
        np.ndarray: 1-d array of length ``2 * len(label) + 1``.
    """
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    # Each row becomes [blank, token]; flattening interleaves them.
    label = np.concatenate([blanks, label], axis=1)
    label = label.reshape(-1)
    # Append blank_id itself rather than label[0]: the two are equal for a
    # non-empty label, but an empty label has no element 0 to copy.
    label = np.append(label, blank_id)
    return label

def forced_align(ctc_probs: torch.Tensor,
                 y: torch.Tensor,
                 blank_id=0) -> list:
    """CTC forced alignment (Viterbi best path through the CTC topology).

    The label sequence is expanded with ``insert_blank`` to
    ``[blank, y0, blank, y1, ..., blank]`` and the highest-scoring path over
    these states is found by dynamic programming. A state may be reached
    from itself, from the previous state, or -- only for a non-blank label
    that differs from the label two states back -- by skipping the blank in
    between.

    Args:
        torch.Tensor ctc_probs: per-frame token scores, 2d tensor (T, D),
            indexed as ``ctc_probs[t][token_id]``. Scores are accumulated by
            addition along a path, so log-probabilities are expected.
        torch.Tensor y: id sequence tensor 1d tensor (L)
        int blank_id: blank symbol index
    Returns:
        list: T token ids, one per frame, taken from the blank-expanded
        label sequence (blank frames carry ``blank_id``). Empty when T == 0.
        If ``y`` cannot be fitted into T frames, every final score is
        ``-inf`` and the returned path is not a meaningful alignment.
    """
    num_frames = ctc_probs.size(0)
    if num_frames == 0:
        return []

    y_insert_blank = insert_blank(y, blank_id)
    num_states = len(y_insert_blank)

    # log_alpha[t, s]: best score of any path ending in state s at frame t.
    log_alpha = torch.full((num_frames, num_states), float('-inf'))  # log of zero
    # state_path[t, s]: predecessor of state s on that best path (row 0 unused).
    # int32 rather than int16 so state indices cannot overflow on long labels.
    state_path = torch.full((num_frames, num_states), -1, dtype=torch.int32)

    # A path starts either in the leading blank or in the first label.
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
    if num_states > 1:
        log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]

    for t in range(1, num_frames):
        for s in range(num_states):
            # Predecessors in preference order (argmax keeps the first
            # maximum, so ties favour staying in s). s - 1 is only added for
            # s >= 1: index -1 would wrap around to the final state.
            prev_state = [s] if s == 0 else [s, s - 1]
            if (s >= 2 and y_insert_blank[s] != blank_id
                    and y_insert_blank[s] != y_insert_blank[s - 2]):
                # Skipping the blank is only allowed between two different
                # labels.
                prev_state.append(s - 2)
            candidates = log_alpha[t - 1, prev_state]
            log_alpha[t, s] = (torch.max(candidates)
                               + ctc_probs[t][y_insert_blank[s]])
            state_path[t, s] = prev_state[int(torch.argmax(candidates))]

    # A path ends either in the trailing blank or in the last label; the
    # trailing blank wins ties.
    final_states = [num_states - 1]
    if num_states > 1:
        final_states.append(num_states - 2)
    final_scores = log_alpha[-1, final_states]

    # Backtrack from the best final state down to frame 0.
    state_seq = [0] * num_frames
    state_seq[-1] = final_states[int(torch.argmax(final_scores))]
    for t in range(num_frames - 1, 0, -1):
        state_seq[t - 1] = int(state_path[t, state_seq[t]])

    return [y_insert_blank[s] for s in state_seq]