import os
import math
import warnings
from itertools import groupby
from typing import List

import torch
import torchaudio


def read_batch(audio_paths: List[str]):
    """Load every file in ``audio_paths`` as a mono 1-D 16 kHz tensor."""
    return [read_audio(audio_path) for audio_path in audio_paths]


def split_into_batches(lst: List[str],
                       batch_size: int = 10):
    """Chunk ``lst`` into consecutive batches of at most ``batch_size`` items."""
    return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]
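

# For example, split_into_batches(['a.wav', 'b.wav', 'c.wav'], batch_size=2)
# returns [['a.wav', 'b.wav'], ['c.wav']]; the last batch may be shorter.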


def read_audio(path: str,
               target_sr: int = 16000):
    """Load audio as a mono 1-D tensor at ``target_sr``, capped in length."""
    wav, sr = torchaudio.load(path)

    # Down-mix multi-channel audio to mono.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    # Resample to the target sample rate if needed.
    if sr != target_sr:
        transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                   new_freq=target_sr)
        wav = transform(wav)
        sr = target_sr

    # Clips longer than 8 * 2**18 samples are time-compressed by resampling so
    # they fit the model's maximum input length (see the note below).
    max_len = 8 * pow(2, 18)
    if len(wav.squeeze(0)) > max_len:
        transform = torchaudio.transforms.Resample(
            orig_freq=target_sr,
            new_freq=int(target_sr * max_len / len(wav.squeeze(0))))
        wav = transform(wav)
    assert len(wav.squeeze(0)) <= max_len
    return wav.squeeze(0)
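

# At the default 16 kHz target rate, the cap of 8 * 2**18 = 2,097,152 samples
# corresponds to roughly 131 s of audio (2_097_152 / 16_000 ≈ 131.07 s). A longer
# clip is resampled to fewer samples but still treated as 16 kHz audio downstream,
# which effectively speeds it up and shifts its pitch; splitting long recordings
# before calling read_audio avoids that distortion.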


def _check_wav_file(path):
    wav_exts = {'wav'}
    return any(path.lower().endswith(ext) for ext in wav_exts)


def get_wav_file_list(wav_file):
    """Return a sorted list of wav paths from a single file or a directory."""
    wavs_lists = []
    if wav_file is None or not os.path.exists(wav_file):
        raise Exception("no wav file found in {}".format(wav_file))

    if os.path.isfile(wav_file) and _check_wav_file(wav_file):
        wavs_lists.append(wav_file)
    elif os.path.isdir(wav_file):
        for single_file in os.listdir(wav_file):
            file_path = os.path.join(wav_file, single_file)
            if os.path.isfile(file_path) and _check_wav_file(file_path):
                wavs_lists.append(file_path)
    if len(wavs_lists) == 0:
        raise Exception("no wav file found in {}".format(wav_file))
    return sorted(wavs_lists)


def prepare_model_input(lst: List[str], batch_size: int = 24, device=torch.device('cpu')):
    """Read audio paths in batches and zero-pad them into model-ready tensors."""
    batches = split_into_batches(lst, batch_size)
    inputs = []
    for batch in batches:
        tensors = read_batch(batch)
        # Pad the sequence length up to the next multiple of 2**18 (min 12800)
        # and the batch dimension up to the next multiple of 4.
        max_seqlength = max(max(len(t) for t in tensors), 12800)
        max_seqlength = math.ceil(max_seqlength / pow(2, 18)) * pow(2, 18)
        max_batch = math.ceil(len(tensors) / 4) * 4
        batch_input = torch.zeros(max_batch, max_seqlength)
        for i, wav in enumerate(tensors):
            batch_input[i, :len(wav)].copy_(wav)
        inputs.append(batch_input.to(device))
    return inputs
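

# Padding example: a batch of 10 clips whose longest tensor holds 300,000 samples
# is padded to ceil(300_000 / 2**18) * 2**18 = 524,288 samples, and the batch
# dimension grows from 10 to 12 zero-filled rows, yielding a (12, 524288) tensor.
# Restricting shapes to this small fixed set lets a shape-specialising backend
# reuse compiled kernels; prepare_model_input_warmup below produces the same set.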


def prepare_model_input_warmup(path: str, device=torch.device('cpu')):
    """Build warm-up tensors covering every padded shape the model may see."""
    audio_tensor = read_audio(path)
    inputs = []
    max_batnum = 24
    min_batnum = 4
    for l in range(8):
        # Sequence lengths are multiples of 2**18, from 1 * 2**18 up to 8 * 2**18.
        max_seqlength = (l + 1) * pow(2, 18)
        # Tile the sample clip until it fills the sequence length.
        single_input = torch.zeros(1, max_seqlength)
        for i in range(int(max_seqlength / len(audio_tensor))):
            single_input[0, i * len(audio_tensor):(i + 1) * len(audio_tensor)].copy_(audio_tensor)
        # Batch sizes are multiples of 4, from 4 up to 24.
        for bn in range(int(max_batnum / min_batnum)):
            warmup_input = torch.zeros(min_batnum * (bn + 1), max_seqlength)
            for h in range(min_batnum * (bn + 1)):
                warmup_input[h, :single_input.shape[1]].copy_(single_input.squeeze(0))
            inputs.append(warmup_input.to(device))
    return inputs
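

# The loops above yield 8 sequence lengths x 6 batch sizes = 48 tensors, covering
# every (batch, length) shape prepare_model_input can emit with its defaults
# (batches padded to multiples of 4 up to 24, lengths to multiples of 2**18 up to
# 8 * 2**18), so a shape-specialising runtime is warmed up before real traffic.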


class Decoder:
    """Greedy CTC decoder with optional word-level timestamp alignment."""

    def __init__(self,
                 labels: List[str]):
        self.labels = labels
        self.blank_idx = self.labels.index('_')
        self.space_idx = self.labels.index(' ')

    def process(self,
                probs, wav_len, word_align):
        assert len(self.labels) == probs.shape[1]
        for_string = []
        argm = torch.argmax(probs, dim=1)
        align_list = [[]]
        for j, i in enumerate(argm):
            if i == self.labels.index('2'):
                # Token '2' repeats the previous character; the '$' marker keeps
                # the duplicated pair from being collapsed by groupby below.
                try:
                    prev = for_string[-1]
                    for_string.append('$')
                    for_string.append(prev)
                    align_list[-1].append(j)
                    continue
                except IndexError:
                    for_string.append(' ')
                    warnings.warn('Token "2" detected at the beginning of sentence, omitting')
                    align_list.append([])
                    continue
            if i != self.blank_idx:
                for_string.append(self.labels[i])
                if i == self.space_idx:
                    align_list.append([])
                else:
                    align_list[-1].append(j)

        # Collapse CTC repeats, drop the '$' repeat markers, and trim whitespace.
        string = ''.join([x[0] for x in groupby(for_string)]).replace('$', '').strip()

        # Drop empty alignment groups left by leading or consecutive spaces.
        align_list = list(filter(lambda x: x, align_list))

        if align_list and wav_len and word_align:
            align_dicts = []
            # Seconds of audio represented by one output frame.
            linear_align_coeff = wav_len / len(argm)
            # Widen each word's frame span by up to 1.5 frames on either side,
            # without overlapping the neighbouring word.
            to_move = min(align_list[0][0], 1.5)
            for i, align_word in enumerate(align_list):
                if len(align_word) == 1:
                    align_word.append(align_word[0])
                align_word[0] = align_word[0] - to_move
                if i == (len(align_list) - 1):
                    to_move = min(1.5, len(argm) - i)
                    align_word[-1] = align_word[-1] + to_move
                else:
                    to_move = min(1.5, (align_list[i + 1][0] - align_word[-1]) / 2)
                    align_word[-1] = align_word[-1] + to_move

            # Convert frame indices to second-resolution timestamps.
            for word, timing in zip(string.split(), align_list):
                align_dicts.append({'word': word,
                                    'start_ts': round(timing[0] * linear_align_coeff, 2),
                                    'end_ts': round(timing[-1] * linear_align_coeff, 2)})

            return string, align_dicts
        return string

    def __call__(self,
                 probs: torch.Tensor,
                 wav_len: float = 0,
                 word_align: bool = False):
        return self.process(probs, wav_len, word_align)
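

# Illustrative call (values are made up, not from a real run): passing the clip
# duration and word_align=True returns per-word timestamps alongside the text,
#     text, words = decoder(probs, wav_len=5.0, word_align=True)
#     # words == [{'word': 'hello', 'start_ts': 0.12, 'end_ts': 0.48}, ...]
# while decoder(probs) with the defaults returns only the transcript string.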


def init_jit_model(model_url: str,
                   device: torch.device = torch.device('cpu')):
    """Download a TorchScript model (if not cached) and return (model, Decoder)."""
    torch.set_grad_enabled(False)

    model_dir = os.path.join(os.path.dirname(__file__), "model")
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, os.path.basename(model_url))

    # Download the checkpoint only if it is not already cached locally.
    if not os.path.isfile(model_path):
        torch.hub.download_url_to_file(model_url,
                                       model_path,
                                       progress=True)

    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, Decoder(model.labels)
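

# A minimal end-to-end sketch, assuming a TorchScript STT checkpoint that exposes
# `.labels` and maps a padded (batch, samples) tensor to per-frame label
# probabilities. MODEL_URL and 'audio_dir' are placeholders, not real paths.
if __name__ == '__main__':
    MODEL_URL = 'https://example.com/stt_model.jit'  # hypothetical checkpoint URL
    device = torch.device('cpu')
    model, decoder = init_jit_model(MODEL_URL, device=device)

    wav_paths = get_wav_file_list('audio_dir')  # directory containing .wav files
    for model_input in prepare_model_input(wav_paths, batch_size=24, device=device):
        output = model(model_input)
        for probs in output:
            print(decoder(probs.cpu()))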