#!/usr/bin/env python3
"""Demo script that runs the pre-trained Emformer RNN-T pipelines in streaming and non-streaming mode.

Example:
    python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/librispeech
"""

import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass
from functools import partial
from typing import Callable

import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3

from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3
from mustc.dataset import MUSTC

# Module-level logger; its level/format are set by ``init_logger`` via the root logger.
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Config:
    """Dataset factory and pre-trained pipeline bundle for one ``--model-type`` value.

    Instances are shared module-level constants, so the dataclass is frozen to
    prevent accidental mutation.
    """

    # Called with the dataset root path; must return an indexable dataset whose
    # items start with the waveform tensor (as ``run_eval_streaming`` consumes it).
    dataset: Callable
    # Pre-trained Emformer RNN-T bundle supplying the decoder, feature extractors
    # and token processor.
    bundle: RNNTBundle


# Maps each ``--model-type`` value to its test-split dataset factory and the
# matching pre-trained bundle.
_CONFIGS = {
    MODEL_TYPE_LIBRISPEECH: Config(
        dataset=partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
        bundle=EMFORMER_RNNT_BASE_LIBRISPEECH,
    ),
    MODEL_TYPE_MUSTC: Config(
        dataset=partial(MUSTC, subset="tst-COMMON"),
        bundle=EMFORMER_RNNT_BASE_MUSTC,
    ),
    MODEL_TYPE_TEDLIUM3: Config(
        dataset=partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
        bundle=EMFORMER_RNNT_BASE_TEDLIUM3,
    ),
}


def run_eval_streaming(args, *, num_samples=10, beam_width=10):
    """Decode dataset utterances with both streaming and non-streaming inference.

    For each of the first ``num_samples`` utterances, the waveform is first fed to
    the decoder chunk by chunk (streaming) while the running transcript is
    reprinted in place on one terminal line. It is then decoded again in a single
    pass over the full waveform (non-streaming) so the two transcripts can be
    compared.

    Args:
        args: Parsed command-line namespace. ``args.model_type`` selects the
            ``_CONFIGS`` entry and ``args.dataset_path`` is passed to its dataset
            factory.
        num_samples: Number of leading utterances to decode (default 10, the
            previously hard-coded value).
        beam_width: Beam width for both the streaming and non-streaming searches
            (default 10, the previously hard-coded value).
    """
    config = _CONFIGS[args.model_type]
    dataset = config.dataset(args.dataset_path)
    bundle = config.bundle

    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()
    feature_extractor = bundle.get_feature_extractor()
    streaming_feature_extractor = bundle.get_streaming_feature_extractor()

    # Each streaming step advances by one segment but also feeds the right-context
    # samples that follow it, so consecutive chunks overlap by the right context.
    hop_length = bundle.hop_length
    num_samples_segment = bundle.segment_length * hop_length
    num_samples_segment_right_context = num_samples_segment + bundle.right_context_length * hop_length

    for sample_idx in range(num_samples):
        waveform = dataset[sample_idx][0].squeeze()

        # Streaming decode: carry decoder state and hypotheses across chunks.
        # ``offset`` is deliberately distinct from ``sample_idx`` (the original
        # reused ``idx`` for both loops).
        state, hypothesis = None, None
        for offset in range(0, len(waveform), num_samples_segment):
            segment = waveform[offset : offset + num_samples_segment_right_context]
            # Trailing chunks can be shorter than the window; zero-pad to a fixed length.
            segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
            with torch.no_grad():
                features, length = streaming_feature_extractor(segment)
                hypos, state = decoder.infer(features, length, beam_width, state=state, hypothesis=hypothesis)
            hypothesis = hypos
            transcript = token_processor(hypos[0][0], lstrip=True)
            # "\r" lets the next partial transcript overwrite this one on the same line.
            print(transcript, end="\r", flush=True)
        print()

        # Non-streaming decode over the full waveform, for comparison.
        with torch.no_grad():
            features, length = feature_extractor(waveform)
            hypos = decoder(features, length, beam_width)
        print(token_processor(hypos[0][0]))
        print()


def parse_args():
    """Parse the demo's command-line arguments.

    Returns:
        argparse.Namespace with ``model_type`` (a key of ``_CONFIGS``),
        ``dataset_path`` (a ``pathlib.Path``) and ``debug`` (bool).
    """
    # The module docstring doubles as the --help description; RawTextHelpFormatter
    # keeps its line breaks and example intact.
    parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--model-type",
        type=str,
        choices=_CONFIGS.keys(),
        required=True,
        help="Pre-trained pipeline (and its matching test dataset) to evaluate.",
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        help="Path to dataset.",
        required=True,
    )
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()


def init_logger(debug):
    """Configure root logging for the demo.

    Args:
        debug: When true, log at DEBUG level with timestamps; otherwise log bare
            messages at INFO level.
    """
    if debug:
        log_format, log_level = "%(asctime)s %(message)s", logging.DEBUG
    else:
        log_format, log_level = "%(message)s", logging.INFO
    logging.basicConfig(
        format=log_format,
        level=log_level,
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def cli_main():
    """Command-line entry point: parse flags, configure logging, then run the demo."""
    cli_args = parse_args()
    init_logger(debug=cli_args.debug)
    run_eval_streaming(cli_args)


# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    cli_main()