Commit ed175137 authored by jamarshon, committed by cpuhrsch

Increasing test coverage (ASR demo) (#248)

parent 42a705d5
@@ -110,3 +110,8 @@ ENV/
 test/assets/sinewave.wav
 torchaudio/version.py
 gen.yml
+
+# Examples
+examples/interactive_asr/data/*.txt
+examples/interactive_asr/data/*.model
+examples/interactive_asr/data/*.pt
@@ -8,15 +8,19 @@ cache:
   directories:
     - /home/travis/download

-# This matrix tests that the code works on Python 3.5, 3.6, and passes lint.
+# This matrix tests that the code works on Python 2.7, 3.5, 3.6, 3.7, passes
+# lint and example tests.
 matrix:
   fast_finish: true
   include:
-    - env: PYTHON_VERSION="3.7"
-    - env: PYTHON_VERSION="3.6"
-    # TODO add this back in when there is a pytorch 1.2 for python 3.5
-    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_TESTS="true"
     - env: PYTHON_VERSION="2.7"
+    - env: PYTHON_VERSION="3.5"
+    - env: PYTHON_VERSION="3.6"
+    - env: PYTHON_VERSION="3.7"
+    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_INSTALL="true" SKIP_TESTS="true"
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
+  allow_failures:
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"

 addons:
   apt:
@@ -24,6 +28,7 @@ addons:
       sox
       libsox-dev
       libsox-fmt-all
+      portaudio19-dev

 notifications:
   email: false
...
@@ -51,7 +51,30 @@ source activate testenv
 pip install -r requirements.txt

 # Install the following only if running tests
-if [[ "$SKIP_TESTS" != "true" ]]; then
+if [[ "$SKIP_INSTALL" != "true" ]]; then
     # TorchAudio CPP Extensions
     python setup.py install
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    # Install dependencies
+    pip install sentencepiece PyAudio
+
+    if [[ ! -d $HOME/download/fairseq ]]; then
+        # Install fairseq from source
+        git clone https://github.com/pytorch/fairseq $HOME/download/fairseq
+    fi
+
+    pushd $HOME/download/fairseq
+    pip install --editable .
+    popd
+
+    mkdir -p $HOME/download/data
+
+    # Install dictionary, sentence piece model, and model
+    # These are cached so they are not downloaded if they already exist
+    wget -nc -O $HOME/download/data/dict.txt https://download.pytorch.org/models/audio/dict.txt || true
+    wget -nc -O $HOME/download/data/spm.model https://download.pytorch.org/models/audio/spm.model || true
+    wget -nc -O $HOME/download/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt || true
+fi
+
+echo "Finished installation"
@@ -32,5 +32,17 @@ if [[ "$RUN_FLAKE8" == "true" ]]; then
 fi

 if [[ "$SKIP_TESTS" != "true" ]]; then
+    echo "run_tests"
     run_tests
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    echo "run_example_tests"
+    pushd examples
+    ASR_MODEL_PATH=$HOME/download/data/model.pt \
+    ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+    ASR_DATA_PATH=$HOME/download/data \
+    ASR_USER_DIR=$HOME/download/fairseq/examples/speech_recognition \
+    python -m unittest test/test_interactive_asr.py
+    popd
+fi
@@ -16,6 +16,9 @@ and the following models

 We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install the dependencies when available.

 ```bash
+# Assume that all commands are from the examples folder
+cd examples
+
 # Install dependencies
 conda install -c pytorch torchaudio
 conda install -c conda-forge librosa
@@ -23,26 +26,38 @@ conda install pyaudio
 pip install sentencepiece

 # Install fairseq from source
-git clone https://github.com/pytorch/fairseq
-cd fairseq
+git clone https://github.com/pytorch/fairseq interactive_asr/fairseq
+pushd interactive_asr/fairseq
 export CFLAGS='-stdlib=libc++' # For Mac only
 pip install --editable .
-cd ..
+popd

 # Install dictionary, sentence piece model, and model
-wget -O ./data/dict.txt https://download.pytorch.org/models/audio/dict.txt
-wget -O ./data/spm.model https://download.pytorch.org/models/audio/spm.model
-wget -O ./data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
+wget -O interactive_asr/data/dict.txt https://download.pytorch.org/models/audio/dict.txt
+wget -O interactive_asr/data/spm.model https://download.pytorch.org/models/audio/spm.model
+wget -O interactive_asr/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
 ```

 ## Run

 On a file
 ```bash
-INPUT_FILE=./data/sample.wav
-python asr.py ./data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+INPUT_FILE=interactive_asr/data/sample.wav
+python -m interactive_asr.asr interactive_asr/data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```

 As a microphone
 ```bash
-python asr.py ./data --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+python -m interactive_asr.asr interactive_asr/data --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
+
+To run the testcase associated with this example
+```bash
+ASR_MODEL_PATH=interactive_asr/data/model.pt \
+ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+ASR_DATA_PATH=interactive_asr/data \
+ASR_USER_DIR=interactive_asr/fairseq/examples/speech_recognition \
+python -m unittest test/test_interactive_asr.py
+```
@@ -11,187 +11,24 @@ Run inference for pre-processed data with a trained model.
 import datetime as dt
 import logging
-import os
-import sys
-import time
-
-import torch
-import sentencepiece as spm
-import torchaudio
-from fairseq import options, tasks, utils
-from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module
-from vad import get_microphone_chunks
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def add_asr_eval_argument(parser):
-    parser.add_argument("--input_file", help="input file")
-    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
-    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
-    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
-    parser.add_argument(
-        "--wfstlm", default=None, help="wfstlm on dictonary output units"
-    )
-    parser.add_argument(
-        "--rnnt_decoding_type",
-        default="greedy",
-        help="wfstlm on dictonary output units",
-    )
-    parser.add_argument(
-        "--lm_weight",
-        default=0.2,
-        help="weight for wfstlm while interpolating with neural score",
-    )
-    parser.add_argument(
-        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
-    )
-    return parser
-
-
-def check_args(args):
-    assert args.path is not None, "--path required for generation!"
-    assert (
-        not args.sampling or args.nbest == args.beam
-    ), "--sampling requires --nbest to be equal to --beam"
-    assert (
-        args.replace_unk is None or args.raw_text
-    ), "--replace-unk requires a raw text dataset (--raw-text)"
-
-
-def process_predictions(args, hypos, sp, tgt_dict):
-    res = []
-    for hypo in hypos[: min(len(hypos), args.nbest)]:
-        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
-        hyp_words = sp.DecodePieces(hyp_pieces.split())
-        res.append(hyp_words)
-    return res
-
-
-def optimize_models(args, use_cuda, models):
-    """Optimize ensemble for generation
-    """
-    for model in models:
-        model.make_generation_fast_(
-            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
-            need_attn=args.print_alignment,
-        )
-        if args.fp16:
-            model.half()
-        if use_cuda:
-            model.cuda()
-
-
-def calc_mean_invstddev(feature):
-    if len(feature.shape) != 2:
-        raise ValueError("We expect the input feature to be 2-D tensor")
-    mean = torch.mean(feature, dim=0)
-    var = torch.var(feature, dim=0)
-    # avoid division by ~zero
-    if (var < sys.float_info.epsilon).any():
-        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
-    return mean, 1.0 / torch.sqrt(var)
-
-
-def calcMN(features):
-    mean, invstddev = calc_mean_invstddev(features)
-    res = (features - mean) * invstddev
-    return res
-
-
-def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
-    num_features = 80
-    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
-    output_cmvn = calcMN(output.cpu().detach())
-    # size (m, n)
-    source = torch.tensor(output_cmvn)
-    frames_lengths = torch.LongTensor([source.size(0)])
-    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
-    source.unsqueeze_(0)
-    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}
-    hypos = task.inference_step(generator, models, sample)
-    assert len(hypos) == 1
-    transcription = []
-    for i in range(len(hypos)):
-        # Process top predictions
-        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
-        transcription.append(hyp_words)
-    return transcription
-
-
+
+from fairseq import options
+from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
+
+
 def main(args):
-    check_args(args)
-    import_user_module(args)
-
-    if args.max_tokens is None and args.max_sentences is None:
-        args.max_tokens = 30000
-    logger.info(args)
-
-    use_cuda = torch.cuda.is_available() and not args.cpu
-
-    # Load dataset splits
-    task = tasks.setup_task(args)
-
-    # Set dictionary
-    tgt_dict = task.target_dictionary
-
-    if args.ctc or args.rnnt:
-        tgt_dict.add_symbol("<ctc_blank>")
-        if args.ctc:
-            logger.info("| decoding a ctc model")
-        if args.rnnt:
-            logger.info("| decoding a rnnt model")
-
-    # Load ensemble
-    logger.info("| loading model(s) from {}".format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(":"),
-        task,
-        model_arg_overrides=eval(args.model_overrides),  # noqa
-    )
-    optimize_models(args, use_cuda, models)
-
-    # Initialize generator
-    generator = task.build_generator(args)
-
-    sp = spm.SentencePieceProcessor()
-    sp.Load(os.path.join(args.data, "spm.model"))
-
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    task, generator, models, sp, tgt_dict = setup_asr(args, logger)
+
+    print("READY!")
     if args.input_file:
-        path = args.input_file
-        if not os.path.exists(path):
-            raise FileNotFoundError("Audio file not found: {}".format(path))
-        waveform, sample_rate = torchaudio.load_wav(path)
-        waveform = waveform.mean(0, True)
-        waveform = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )(waveform)
-        print(sample_rate, waveform.shape)
-        start = time.time()
-        transcription = transcribe(
-            waveform, args, task, generator, models, sp, tgt_dict
-        )
-        end = time.time()
+        transcription_time, transcription = transcribe_file(args, task, generator, models, sp, tgt_dict)
         print("transcription:", transcription)
-        print(end - start)
+        print("transcription_time:", transcription_time)
     else:
-        print("READY!")
-        for (waveform, sample_rate) in get_microphone_chunks():
-            waveform = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=16000
-            )(waveform.reshape(1, -1))
-            transcription = transcribe(
-                waveform, args, task, generator, models, sp, tgt_dict
-            )
+        for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
             print(
                 "{}: {}".format(
                     dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]
...
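The hunk above is truncated before asr.py's command-line entry point, so the wiring of `add_asr_eval_argument` into fairseq's argument parsing is not visible here. For orientation only, the sketch below shows the usual fairseq-style pattern such an entry point follows (assuming fairseq's `options.get_generation_parser` and `options.parse_args_and_arch` helpers); it is an illustration of the pattern, not the file's actual remaining lines.

```python
# Illustrative sketch only: a typical fairseq-style CLI entry point for the
# refactored asr.py. Names and structure here are assumptions, not the
# truncated contents of the diff.
from fairseq import options

from interactive_asr.asr import main
from interactive_asr.utils import add_asr_eval_argument


def cli_main():
    # Start from fairseq's standard generation options, then add the
    # ASR-specific flags defined in interactive_asr/utils.py.
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
```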
#!/usr/bin/env python3

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import sys
import time

import torch
import torchaudio
import sentencepiece as spm

from fairseq import tasks
from fairseq.utils import load_ensemble_for_inference, import_user_module
from interactive_asr.vad import get_microphone_chunks


def add_asr_eval_argument(parser):
    parser.add_argument("--input_file", help="input file")
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictonary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="wfstlm on dictonary output units",
    )
    parser.add_argument(
        "--lm_weight",
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    return parser


def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def process_predictions(args, hypos, sp, tgt_dict):
    res = []
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        res.append(hyp_words)
    return res


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation
    """
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def calc_mean_invstddev(feature):
    if len(feature.shape) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")
    mean = torch.mean(feature, dim=0)
    var = torch.var(feature, dim=0)
    # avoid division by ~zero
    if (var < sys.float_info.epsilon).any():
        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
    return mean, 1.0 / torch.sqrt(var)


def calcMN(features):
    mean, invstddev = calc_mean_invstddev(features)
    res = (features - mean) * invstddev
    return res


def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
    num_features = 80
    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
    output_cmvn = calcMN(output.cpu().detach())

    # size (m, n)
    source = output_cmvn
    frames_lengths = torch.LongTensor([source.size(0)])

    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
    source.unsqueeze_(0)
    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}
    hypos = task.inference_step(generator, models, sample)
    assert len(hypos) == 1
    transcription = []
    for i in range(len(hypos)):
        # Process top predictions
        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
        transcription.append(hyp_words)
    return transcription


def setup_asr(args, logger):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    return task, generator, models, sp, tgt_dict


def transcribe_file(args, task, generator, models, sp, tgt_dict):
    path = args.input_file
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))

    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=16000
    )(waveform)

    start = time.time()
    transcription = transcribe(
        waveform, args, task, generator, models, sp, tgt_dict
    )
    transcription_time = time.time() - start
    return transcription_time, transcription


def get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
    for (waveform, sample_rate) in get_microphone_chunks():
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform.reshape(1, -1))
        transcription = transcribe(
            waveform, args, task, generator, models, sp, tgt_dict
        )
        yield transcription
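A note on the feature pipeline in the new `interactive_asr.utils` module above: `transcribe` computes 80-bin filterbank features with `torchaudio.compliance.kaldi.fbank`, then `calcMN` applies per-utterance mean and variance normalization (CMVN) along the time axis before the frames are passed to the fairseq task. The snippet below is a small self-contained check of that normalization step; it mirrors the `calc_mean_invstddev`/`calcMN` computation with plain `torch` on a fake feature matrix, so it runs without fairseq or the example installed.

```python
import sys

import torch

# Fake "fbank" features: 100 frames x 80 mel bins, deliberately offset and scaled.
feats = torch.randn(100, 80) * 3.0 + 5.0

# Same computation as calc_mean_invstddev/calcMN above.
mean = torch.mean(feats, dim=0)
var = torch.var(feats, dim=0)
if (var < sys.float_info.epsilon).any():
    invstddev = 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
else:
    invstddev = 1.0 / torch.sqrt(var)
normalized = (feats - mean) * invstddev

# After CMVN each mel bin has approximately zero mean and unit variance.
print(normalized.mean(dim=0).abs().max())  # ~0
print(normalized.std(dim=0).mean())        # ~1
```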
import argparse
import logging
import os
import unittest

from interactive_asr.utils import setup_asr, transcribe_file


class ASRTest(unittest.TestCase):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    arguments_dict = {
        'path': '/scratch/jamarshon/downloads/model.pt',
        'input_file': '/scratch/jamarshon/audio/examples/interactive_asr/data/sample.wav',
        'data': '/scratch/jamarshon/downloads',
        'user_dir': '/scratch/jamarshon/fairseq-py/examples/speech_recognition',
        'no_progress_bar': False, 'log_interval': 1000, 'log_format': None,
        'tensorboard_logdir': '', 'tbmf_wrapper': False, 'seed': 1, 'cpu': True,
        'fp16': False, 'memory_efficient_fp16': False, 'fp16_init_scale': 128,
        'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0,
        'min_loss_scale': 0.0001, 'threshold_loss_scale': None,
        'criterion': 'cross_entropy', 'tokenizer': None, 'bpe': None, 'optimizer':
        'nag', 'lr_scheduler': 'fixed', 'task': 'speech_recognition', 'num_workers': 0,
        'skip_invalid_size_inputs_valid_test': False, 'max_tokens': 10000000,
        'max_sentences': None, 'required_batch_size_multiple': 8, 'dataset_impl': None,
        'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0,
        'remove_bpe': None, 'quiet': False, 'model_overrides': '{}',
        'results_path': None, 'beam': 40, 'nbest': 1, 'max_len_a': 0,
        'max_len_b': 200, 'min_len': 1, 'match_source_len': False,
        'no_early_stop': False, 'unnormalized': False, 'no_beamable_mm': False,
        'lenpen': 1, 'unkpen': 0, 'replace_unk': None, 'sacrebleu': False,
        'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0,
        'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0,
        'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5,
        'print_alignment': False, 'ctc': False,
        'rnnt': False, 'kspmodel': None, 'wfstlm': None, 'rnnt_decoding_type': 'greedy',
        'lm_weight': 0.2, 'rnnt_len_penalty': -0.5, 'momentum': 0.99, 'weight_decay': 0.0,
        'force_anneal': None, 'lr_shrink': 0.1, 'warmup_updates': 0}

    arguments_dict['path'] = os.environ.get('ASR_MODEL_PATH', None)
    arguments_dict['input_file'] = os.environ.get('ASR_INPUT_FILE', None)
    arguments_dict['data'] = os.environ.get('ASR_DATA_PATH', None)
    arguments_dict['user_dir'] = os.environ.get('ASR_USER_DIR', None)

    args = argparse.Namespace(**arguments_dict)

    def test_transcribe_file(self):
        task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
        _, transcription = transcribe_file(self.args, task, generator, models, sp, tgt_dict)

        expected_transcription = [['THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG']]
        self.assertEqual(transcription, expected_transcription, msg=str(transcription))


if __name__ == "__main__":
    unittest.main()
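One detail worth noting in this test: the class body reads `ASR_MODEL_PATH`, `ASR_INPUT_FILE`, `ASR_DATA_PATH`, and `ASR_USER_DIR` at import time, so the environment has to be populated before the test module is loaded, which is why both the CI script and the README set the variables on the `python -m unittest` command line. A hypothetical local runner, with paths assumed to match the README layout, could drive the same invocation from Python:

```python
# Hypothetical runner (assumed paths): the environment must be in place before
# test/test_interactive_asr.py is imported, because ASRTest reads os.environ
# in its class body. Run this from the examples/ folder.
import os
import subprocess
import sys

env = dict(
    os.environ,
    ASR_MODEL_PATH="interactive_asr/data/model.pt",
    ASR_INPUT_FILE="interactive_asr/data/sample.wav",
    ASR_DATA_PATH="interactive_asr/data",
    ASR_USER_DIR="interactive_asr/fairseq/examples/speech_recognition",
)

# Same invocation the README and the CI script use.
subprocess.run(
    [sys.executable, "-m", "unittest", "test/test_interactive_asr.py"],
    env=env,
    check=True,
)
```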