Commit e3b40d1c authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Refactor eval and pipeline_demo scripts in emformer_rnnt (#2238)

Summary:
- Add docstring to `eval.py` and `pipeline_demo.py` under `emformer_rnnt` directory.
- Refactor logger and ArgumentParser

Pull Request resolved: https://github.com/pytorch/audio/pull/2238

Reviewed By: mthrok

Differential Revision: D34267059

Pulled By: nateanl

fbshipit-source-id: 4b8d3d183ee7bc0ad71ce305cab87bfa90208b2e
parent eeba91dc
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Evaluate the lightning module by loading the checkpoint, the SentencePiece model, and the global_stats.json.
Example:
python eval.py --model-type tedlium3 --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt
--dataset-path ./datasets/tedlium --sp-model-path ./spm_bpe_500.model
"""
import logging import logging
import pathlib import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser, RawTextHelpFormatter
import torch import torch
import torchaudio import torchaudio
...@@ -11,7 +17,7 @@ from mustc.lightning import MuSTCRNNTModule ...@@ -11,7 +17,7 @@ from mustc.lightning import MuSTCRNNTModule
from tedlium3.lightning import TEDLIUM3RNNTModule from tedlium3.lightning import TEDLIUM3RNNTModule
logger = logging.getLogger() logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2): def compute_word_level_distance(seq1, seq2):
...@@ -79,7 +85,7 @@ def get_lightning_module(args): ...@@ -79,7 +85,7 @@ def get_lightning_module(args):
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument( parser.add_argument(
"--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True "--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True
) )
......
#!/usr/bin/env python3 #!/usr/bin/env python3
"""The demo script for testing the pre-trained Emformer RNNT pipelines.
Example:
python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/librispeech
"""
import logging import logging
import pathlib import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment