pipeline_demo.py 3.43 KB
Newer Older
1
#!/usr/bin/env python3
2
3
4
5
6
"""The demo script for testing the pre-trained Emformer RNNT pipelines.

Example:
python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/librispeech
"""
7
8
import logging
import pathlib
9
10
11
12
from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass
from functools import partial
from typing import Callable
13
14
15
16
17

import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
18
from torchaudio.pipelines import RNNTBundle
19
20
21
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3


22
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
23
24


25
26
27
28
@dataclass
class Config:
    """Pairs a dataset factory with the pre-trained RNN-T bundle used on it."""

    # Callable invoked with the dataset path to build the evaluation dataset.
    dataset: Callable
    # Pipeline bundle supplying the decoder, token processor and feature extractors.
    bundle: RNNTBundle
29
30


31
32
33
34
35
36
37
38
39
40
# Supported model types, each mapped to the dataset factory and pipeline bundle
# that the demo uses for it.
_CONFIGS = {
    # LibriSpeech: "test-clean" split with the base LibriSpeech Emformer RNN-T bundle.
    MODEL_TYPE_LIBRISPEECH: Config(
        dataset=partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
        bundle=EMFORMER_RNNT_BASE_LIBRISPEECH,
    ),
    # TED-LIUM release 3: "test" subset with the prototype TED-LIUM 3 bundle.
    MODEL_TYPE_TEDLIUM3: Config(
        dataset=partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
        bundle=EMFORMER_RNNT_BASE_TEDLIUM3,
    ),
}
41
42
43


def run_eval_streaming(args):
    """Transcribe the first few dataset samples in streaming and non-streaming mode.

    For each sample the waveform is first decoded chunk by chunk, printing the
    partial transcript as it is produced, and then decoded once more in a single
    pass over the whole utterance so the two outputs can be compared.

    At most 10 samples are processed; datasets with fewer items are processed
    in full instead of raising ``IndexError``.

    Args:
        args: Parsed command-line namespace. ``args.model_type`` selects the
            entry of ``_CONFIGS`` and ``args.dataset_path`` is passed to that
            entry's dataset factory.
    """
    num_demo_samples = 10  # Upper bound on the number of utterances to transcribe.
    beam_width = 10  # Third positional argument of both decoder calls below.

    config = _CONFIGS[args.model_type]
    dataset = config.dataset(args.dataset_path)
    bundle = config.bundle

    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()
    feature_extractor = bundle.get_feature_extractor()
    streaming_feature_extractor = bundle.get_streaming_feature_extractor()

    # Chunk sizes expressed in waveform samples.
    hop_length = bundle.hop_length
    num_samples_segment = bundle.segment_length * hop_length
    num_samples_segment_right_context = num_samples_segment + bundle.right_context_length * hop_length

    for sample_idx in range(min(num_demo_samples, len(dataset))):
        waveform = dataset[sample_idx][0].squeeze()

        # Streaming decode: advance one segment at a time; each chunk also
        # carries the right-context samples that follow the segment.
        state, hypothesis = None, None
        for start in range(0, len(waveform), num_samples_segment):
            segment = waveform[start : start + num_samples_segment_right_context]
            # Zero-pad the trailing chunk so every chunk has the same length.
            segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
            with torch.no_grad():
                features, length = streaming_feature_extractor(segment)
                # Decoder state and best hypothesis are carried across chunks.
                hypos, state = decoder.infer(features, length, beam_width, state=state, hypothesis=hypothesis)
            hypothesis = hypos[0]
            transcript = token_processor(hypothesis.tokens, lstrip=False)
            # Flush so the partial transcript shows up as soon as it is decoded.
            print(transcript, end="", flush=True)
        print()

        # Non-streaming decode of the full utterance, for comparison.
        with torch.no_grad():
            features, length = feature_extractor(waveform)
            hypos = decoder(features, length, beam_width)
        print(token_processor(hypos[0].tokens))
        print()


def parse_args():
    """Parse the demo's command-line options.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``, with
        attributes ``model_type``, ``dataset_path`` and ``debug``.
    """
    cli = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
    # Only model types that have an entry in ``_CONFIGS`` are accepted.
    cli.add_argument("--model-type", type=str, choices=_CONFIGS.keys(), required=True)
    cli.add_argument("--dataset-path", type=pathlib.Path, required=True, help="Path to dataset.")
    cli.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return cli.parse_args()


def init_logger(debug):
    """Configure logging through ``logging.basicConfig``.

    Args:
        debug: When truthy, log at DEBUG level and prefix each message with a
            timestamp; otherwise log at INFO level with the bare message.
    """
    if debug:
        level, fmt = logging.DEBUG, "%(asctime)s %(message)s"
    else:
        level, fmt = logging.INFO, "%(message)s"
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
    """Command-line entry point: parse options, set up logging, run the demo."""
    options = parse_args()
    init_logger(options.debug)
    run_eval_streaming(options)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    cli_main()