import argparse
import logging
import os
import unittest

from interactive_asr.utils import setup_asr, transcribe_file


class ASRTest(unittest.TestCase):
    """End-to-end check that a sample audio file is transcribed as expected.

    The model, input audio, data directory and user directory are read from
    the environment variables ``ASR_MODEL_PATH``, ``ASR_INPUT_FILE``,
    ``ASR_DATA_PATH`` and ``ASR_USER_DIR``.  Every other option is fixed in
    ``arguments_dict`` and exposed to the helpers as an ``argparse.Namespace``
    through ``args``.
    """

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    arguments_dict = {
        # Machine-specific locations: taken from the environment (None when
        # the variable is unset) rather than hard-coded to one user's paths.
        "path": os.environ.get("ASR_MODEL_PATH"),
        "input_file": os.environ.get("ASR_INPUT_FILE"),
        "data": os.environ.get("ASR_DATA_PATH"),
        "user_dir": os.environ.get("ASR_USER_DIR"),
        # Fixed options (presumably the defaults of the underlying ASR
        # toolkit's command line -- confirm against setup_asr).
        "no_progress_bar": False,
        "log_interval": 1000,
        "log_format": None,
        "tensorboard_logdir": "",
        "tbmf_wrapper": False,
        "seed": 1,
        "cpu": True,
        "fp16": False,
        "memory_efficient_fp16": False,
        "fp16_init_scale": 128,
        "fp16_scale_window": None,
        "fp16_scale_tolerance": 0.0,
        "min_loss_scale": 0.0001,
        "threshold_loss_scale": None,
        "criterion": "cross_entropy",
        "tokenizer": None,
        "bpe": None,
        "optimizer": "nag",
        "lr_scheduler": "fixed",
        "task": "speech_recognition",
        "num_workers": 0,
        "skip_invalid_size_inputs_valid_test": False,
        "max_tokens": 10000000,
        "max_sentences": None,
        "required_batch_size_multiple": 8,
        "dataset_impl": None,
        "gen_subset": "test",
        "num_shards": 1,
        "shard_id": 0,
        "remove_bpe": None,
        "quiet": False,
        "model_overrides": "{}",
        "results_path": None,
        "beam": 40,
        "nbest": 1,
        "max_len_a": 0,
        "max_len_b": 200,
        "min_len": 1,
        "match_source_len": False,
        "no_early_stop": False,
        "unnormalized": False,
        "no_beamable_mm": False,
        "lenpen": 1,
        "unkpen": 0,
        "replace_unk": None,
        "sacrebleu": False,
        "score_reference": False,
        "prefix_size": 0,
        "no_repeat_ngram_size": 0,
        "sampling": False,
        "sampling_topk": -1,
        "sampling_topp": -1.0,
        "temperature": 1.0,
        "diverse_beam_groups": -1,
        "diverse_beam_strength": 0.5,
        "print_alignment": False,
        "ctc": False,
        "rnnt": False,
        "kspmodel": None,
        "wfstlm": None,
        "rnnt_decoding_type": "greedy",
        "lm_weight": 0.2,
        "rnnt_len_penalty": -0.5,
        "momentum": 0.99,
        "weight_decay": 0.0,
        "force_anneal": None,
        "lr_shrink": 0.1,
        "warmup_updates": 0,
    }

    args = argparse.Namespace(**arguments_dict)

    # Attribute of ``args`` -> environment variable that supplies it.
    _REQUIRED_ENV = {
        "path": "ASR_MODEL_PATH",
        "input_file": "ASR_INPUT_FILE",
        "data": "ASR_DATA_PATH",
        "user_dir": "ASR_USER_DIR",
    }

    def setUp(self):
        """Skip the test when a required path has not been configured.

        Without this guard an unset variable leaves ``None`` in ``args`` and
        the failure only surfaces later, inside the ASR helpers, with no hint
        about which variable is missing.  ``self.args`` is inspected (rather
        than the environment) so a subclass that overrides ``args`` with real
        paths still runs.
        """
        missing = [
            env_name
            for attr, env_name in self._REQUIRED_ENV.items()
            if not getattr(self.args, attr, None)
        ]
        if missing:
            self.skipTest(
                "ASR test resources not configured; set environment "
                "variable(s): " + ", ".join(missing)
            )

    def test_transcribe_file(self):
        """Transcribe ``args.input_file`` and compare with the known text."""
        task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
        # Only the second element of the returned pair is checked here.
        _, transcription = transcribe_file(
            self.args, task, generator, models, sp, tgt_dict
        )

        expected_transcription = [["THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG"]]
        # Include the actual output in the failure message for easier triage.
        self.assertEqual(
            transcription, expected_transcription, msg=str(transcription)
        )


if __name__ == "__main__":
    unittest.main()