test_interactive_asr.py 3.27 KB
Newer Older
flyingdown's avatar
flyingdown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import logging
import os
import unittest

from interactive_asr.utils import setup_asr, transcribe_file


class ASRTest(unittest.TestCase):
    """End-to-end test of ``setup_asr`` + ``transcribe_file`` on one audio file.

    Machine-specific resources are read from environment variables when the
    class body is executed:

    * ``ASR_MODEL_PATH`` -> ``args.path``
    * ``ASR_INPUT_FILE`` -> ``args.input_file``
    * ``ASR_DATA_PATH``  -> ``args.data``
    * ``ASR_USER_DIR``   -> ``args.user_dir``

    If any of them is unset, the test is skipped with a message naming the
    missing variables, instead of handing ``None`` to ``setup_asr``.
    """

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Argument name -> environment variable that supplies it. These paths are
    # machine-specific, so no developer-local default is hard-coded here.
    _ENV_VARS = {
        "path": "ASR_MODEL_PATH",
        "input_file": "ASR_INPUT_FILE",
        "data": "ASR_DATA_PATH",
        "user_dir": "ASR_USER_DIR",
    }

    # Options handed to setup_asr / transcribe_file via ``args``.
    arguments_dict = {
        # Each resource path is None when its environment variable is unset.
        **{name: os.environ.get(var) for name, var in _ENV_VARS.items()},
        "no_progress_bar": False,
        "log_interval": 1000,
        "log_format": None,
        "tensorboard_logdir": "",
        "tbmf_wrapper": False,
        "seed": 1,
        "cpu": True,
        "fp16": False,
        "memory_efficient_fp16": False,
        "fp16_init_scale": 128,
        "fp16_scale_window": None,
        "fp16_scale_tolerance": 0.0,
        "min_loss_scale": 0.0001,
        "threshold_loss_scale": None,
        "criterion": "cross_entropy",
        "tokenizer": None,
        "bpe": None,
        "optimizer": "nag",
        "lr_scheduler": "fixed",
        "task": "speech_recognition",
        "num_workers": 0,
        "skip_invalid_size_inputs_valid_test": False,
        "max_tokens": 10000000,
        "max_sentences": None,
        "required_batch_size_multiple": 8,
        "dataset_impl": None,
        "gen_subset": "test",
        "num_shards": 1,
        "shard_id": 0,
        "remove_bpe": None,
        "quiet": False,
        "model_overrides": "{}",
        "results_path": None,
        "beam": 40,
        "nbest": 1,
        "max_len_a": 0,
        "max_len_b": 200,
        "min_len": 1,
        "match_source_len": False,
        "no_early_stop": False,
        "unnormalized": False,
        "no_beamable_mm": False,
        "lenpen": 1,
        "unkpen": 0,
        "replace_unk": None,
        "sacrebleu": False,
        "score_reference": False,
        "prefix_size": 0,
        "no_repeat_ngram_size": 0,
        "sampling": False,
        "sampling_topk": -1,
        "sampling_topp": -1.0,
        "temperature": 1.0,
        "diverse_beam_groups": -1,
        "diverse_beam_strength": 0.5,
        "print_alignment": False,
        "ctc": False,
        "rnnt": False,
        "kspmodel": None,
        "wfstlm": None,
        "rnnt_decoding_type": "greedy",
        "lm_weight": 0.2,
        "rnnt_len_penalty": -0.5,
        "momentum": 0.99,
        "weight_decay": 0.0,
        "force_anneal": None,
        "lr_shrink": 0.1,
        "warmup_updates": 0,
    }
    args = argparse.Namespace(**arguments_dict)

    def test_transcribe_file(self):
        """Transcribe ``args.input_file`` and compare with the expected sentence.

        Skipped when any resource path in ``self.args`` is None, i.e. its
        environment variable was not provided.
        """
        missing = [
            var
            for name, var in self._ENV_VARS.items()
            if getattr(self.args, name, None) is None
        ]
        if missing:
            # Without these resources setup_asr cannot run; report why
            # instead of failing deep inside model loading.
            self.skipTest("environment variable(s) not set: " + ", ".join(missing))

        task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
        _, transcription = transcribe_file(
            self.args, task, generator, models, sp, tgt_dict
        )

        expected_transcription = [["THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG"]]
        self.assertEqual(transcription, expected_transcription, msg=str(transcription))


if __name__ == "__main__":
    # Executed only when this file is run as a script; hands control to
    # unittest's command-line runner for the test cases in this module.
    unittest.main(module="__main__")