rapid_paraformer.py 4.91 KB
Newer Older
SWHL's avatar
SWHL committed
1
2
3
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
4
import traceback
SWHL's avatar
SWHL committed
5
from pathlib import Path
SWHL's avatar
SWHL committed
6
from typing import List, Union, Tuple
SWHL's avatar
SWHL committed
7
8
9
10

import librosa
import numpy as np

SWHL's avatar
SWHL committed
11
12
13
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
                    read_yaml)
SWHL's avatar
SWHL committed
14

15
# Module-level logger created by the project's ``get_logger`` helper.
# NOTE(review): this name shadows the stdlib ``logging`` module inside this
# file; RapidParaformer.__call__ uses it as ``logging.warning(...)``.
logging = get_logger()
SWHL's avatar
SWHL committed
16
17
18


class RapidParaformer():
    """Paraformer speech recognizer driven by an ONNX model session.

    The token converter, tokenizer, feature frontend and model session
    (``OrtInferSession``) are all built from a single YAML config file.
    """

    def __init__(self, config_path: Union[str, Path]) -> None:
        """Build the recognizer from a YAML config.

        Args:
            config_path: Path to the YAML config. The sections read here are
                ``TokenIDConverter``, ``CharTokenizer``, ``WavFrontend``
                (``cmvn_file`` and ``frontend_conf``) and ``Model`` (which
                must also contain ``batch_size``).

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
        """
        if not Path(config_path).exists():
            raise FileNotFoundError(f'{config_path} does not exist.')

        config = read_yaml(config_path)

        self.converter = TokenIDConverter(**config['TokenIDConverter'])
        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
        self.frontend = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **config['WavFrontend']['frontend_conf']
        )
        self.ort_infer = OrtInferSession(config['Model'])
        # Number of utterances fed to the model per inference call.
        self.batch_size = config['Model']['batch_size']

    def __call__(self,
                 wav_content: Union[str, Path, np.ndarray, List[str]]
                 ) -> List:
        """Transcribe one or more utterances.

        Args:
            wav_content: A wav file path, an already-loaded waveform array,
                or a list of paths and/or waveform arrays.

        Returns:
            One transcription per input utterance, in input order. If ONNX
            inference fails for a batch (logged as silence/noise), every
            utterance of that batch gets an empty string. This keeps the
            result list aligned with the inputs.
        """
        waveform_list = self.load_data(wav_content)
        waveform_nums = len(waveform_list)

        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            batch = waveform_list[beg_idx:end_idx]

            feats, feats_len = self.extract_feat(batch)

            try:
                am_scores, valid_token_lens = self.infer(feats, feats_len)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                # One placeholder per utterance. Returning [] here would
                # shift every later result to the wrong input index.
                preds = [''] * len(batch)
            else:
                preds = self.decode(am_scores, valid_token_lens)

            asr_res.extend(preds)
        return asr_res

    def load_data(self,
                  wav_content: Union[str, Path, np.ndarray, List[str]]
                  ) -> List:
        """Normalize ``wav_content`` into a list of waveform arrays.

        File paths are read with ``librosa.load(path, sr=None)``, which keeps
        the file's native sample rate and does not resample. A leading axis
        of size 1 is then added. Arrays are passed through unchanged.

        Raises:
            TypeError: If ``wav_content`` (or an item of a list) is not a
                path or a numpy array.
        """
        def load_wav(path: Union[str, Path]) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=None)
            return waveform[None, ...]

        def to_waveform(item) -> np.ndarray:
            if isinstance(item, np.ndarray):
                return item
            if isinstance(item, (str, Path)):
                return load_wav(item)
            raise TypeError(
                f'The type of {item} is not in [str, np.ndarray, list]')

        if isinstance(wav_content, list):
            return [to_waveform(item) for item in wav_content]
        return [to_waveform(wav_content)]

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute frontend features for a batch of waveforms.

        Each waveform goes through ``frontend.fbank`` and then
        ``frontend.lfr_cmvn``. The results are zero-padded to the longest
        one.

        Returns:
            ``(feats, feats_len)``: the padded float32 feature batch and the
            int32 array of per-utterance feature lengths.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each 2-D feature matrix along axis 0 to ``max_feat_len``.

        Returns the stacked batch as float32.
        """
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the model session.

        Returns its two outputs: per-position token scores and the number of
        valid tokens per utterance.
        """
        am_scores, token_nums = self.ort_infer([feats, feats_len])
        return am_scores, token_nums

    def decode(self, am_scores: np.ndarray,
               token_nums: np.ndarray) -> List[str]:
        """Greedy-decode every utterance of a batch (see ``decode_one``)."""
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> str:
        """Greedy-decode a single utterance.

        Args:
            am_score: Per-position token scores. The argmax over the last
                axis is taken as the token id at each position.
            valid_token_num: Token count reported by the model for this
                utterance. After blank/eos removal, only the first
                ``valid_token_num - 1`` tokens are kept; the rest are
                treated as padding.

        Returns:
            The decoded text.
        """
        token_int = am_score.argmax(axis=-1).tolist()

        # Remove the blank symbol (assumed id 0) and eos (id 2).
        token_int = [tid for tid in token_int if tid not in (0, 2)]

        # Change integer ids to tokens, then cut on the *token* list. Cutting
        # the final text by a token count would split multi-character tokens.
        # max(..., 0) stops valid_token_num == 0 from becoming a [:-1] slice.
        tokens = self.converter.ids2tokens(token_int)
        tokens = tokens[:max(int(valid_token_num) - 1, 0)]
        return self.tokenizer.tokens2text(tokens)


if __name__ == '__main__':
    import sys

    # Config is expected at <project>/resources/config.yaml.
    project_dir = Path(__file__).resolve().parent.parent
    cfg_path = project_dir / 'resources' / 'config.yaml'
    paraformer = RapidParaformer(cfg_path)

    # Transcribe each wav path given on the command line, or the sample file
    # when none is given. The sample path is relative to the current working
    # directory.
    wav_files = sys.argv[1:] or ['0478_00017.wav']
    for wav_file in wav_files:
        result = paraformer(wav_file)
        print(result)