rapid_paraformer.py 2.44 KB
Newer Older
SWHL's avatar
SWHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path
from typing import List

import librosa
import numpy as np

from .utils import (CharTokenizer, Hypothesis, OrtInferSession,
                    TokenIDConverter, WavFrontend, read_yaml)

# Absolute directory of this module; the default ``config.yaml`` is looked up here.
cur_dir = Path(__file__).resolve().parents[0]


class RapidParaformer:
    """Paraformer speech recognizer backed by an ONNX Runtime session.

    The token converter, character tokenizer, feature frontend and ONNX model
    session are all built from a single YAML config file.
    """

    def __init__(self, config_path: str = None) -> None:
        """Build the recognizer from a YAML config.

        Args:
            config_path: Path to a YAML config file. When falsy, the
                ``config.yaml`` located next to this module is used.
        """
        # Parse only the config that is actually used: the bundled default is
        # no longer read (and no longer required to exist) when a custom path
        # is supplied.
        config = read_yaml(config_path if config_path else (cur_dir / 'config.yaml'))

        self.converter = TokenIDConverter(**config['TokenIDConverter'])
        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
        frontend_conf = config['WavFrontend']['frontend_conf']
        self.frontend_asr = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **frontend_conf
        )
        # Rate at which audio files are decoded before feature extraction.
        # NOTE(review): assumes the frontend's sample rate is configured under
        # the ``fs`` key, falling back to 16 kHz — confirm against
        # WavFrontend / config.yaml.
        self.sample_rate = frontend_conf.get('fs', 16000)
        self.ort_infer = OrtInferSession(config['Model'])

    def __call__(self, wav_path: str) -> List:
        """Transcribe one audio file.

        Args:
            wav_path: Path to an audio file readable by ``librosa.load``.

        Returns:
            One entry per element yielded by the ONNX session output
            (presumably one per batch item), each being the list of strings
            returned by :meth:`infer_one_feat`.
        """
        waveform = self._load_waveform(wav_path)

        speech, _ = self.frontend_asr.forward_fbank(waveform)
        feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
        am_scores = self.ort_infer(input_content=[feats, feats_len])

        return [self.infer_one_feat(am_score) for am_score in am_scores]

    def _load_waveform(self, wav_path: str) -> np.ndarray:
        """Decode ``wav_path`` to samples with a leading batch axis of size 1."""
        # Without an explicit ``sr`` librosa resamples to its 22050 Hz default,
        # which does not match the frontend; decode at the configured rate.
        waveform, _ = librosa.load(wav_path, sr=self.sample_rate)
        return waveform[None, ...]

    def infer_one_feat(self, am_score: np.ndarray) -> List[str]:
        """Greedy-decode one acoustic-score matrix into text.

        Args:
            am_score: Scores whose last axis indexes the vocabulary; the argmax
                over that axis is the predicted token id at each position.
                NOTE(review): presumably shaped (num_tokens, vocab_size).

        Returns:
            A single-element list holding the decoded text.
        """
        token_ids = am_score.argmax(axis=-1).tolist()

        # Greedy decoding yields exactly one hypothesis, so no sos/eos padding
        # or n-best loop is needed. Drop id 0 (assumed blank) and id 2 (eos);
        # id 1 (sos) is kept if the model predicts it, as before.
        token_int = [tok for tok in token_ids if tok not in (0, 2)]

        # Change integer ids to tokens, then tokens to text.
        tokens = self.converter.ids2tokens(token_int)
        return [self.tokenizer.tokens2text(tokens)]


if __name__ == '__main__':
    import sys

    # Demo entry point: transcribe each audio path given on the command line
    # once, falling back to the sample file name used previously. (The old
    # version re-ran the same file 1000 times, a leftover benchmark loop.)
    wav_files = sys.argv[1:] or ['0478_00017.wav']

    paraformer = RapidParaformer()
    for wav_file in wav_files:
        print(paraformer(wav_file))