Unverified Commit 5d0b0642 authored by moto, committed by GitHub

Add speech recognition C++ example (#1538)

parent fad19fab
...@@ -2,3 +2,6 @@
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
[submodule "examples/libtorchaudio/simplectc"]
path = examples/libtorchaudio/simplectc
url = https://github.com/mthrok/ctcdecode
build
data/output.wav
*.zip
output
...@@ -14,4 +14,6 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_subdirectory(../.. libtorchaudio)
add_subdirectory(simplectc)
add_subdirectory(augmentation)
add_subdirectory(speech_recognition)
# Libtorchaudio Examples
* [Augmentation](./augmentation)
* [Speech Recognition with wav2vec2.0](./speech_recognition)
## Build
...@@ -14,6 +15,7 @@ It is currently not distributed, and it will be built alongside with the applica
The following commands will build `libtorchaudio` and applications.
```bash
git submodule update
mkdir build
cd build
cmake -GNinja \
...
...@@ -8,6 +8,7 @@ build_dir="${this_dir}/build"
mkdir -p "${build_dir}"
cd "${build_dir}"
git submodule update
cmake -GNinja \
-DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \
-DBUILD_SOX=ON \
...
Subproject commit b1a30d7a65342012e0d2524d9bae1c5412b24a23
add_executable(transcribe transcribe.cpp)
add_executable(transcribe_list transcribe_list.cpp)
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
set_property(TARGET transcribe PROPERTY CXX_STANDARD 14)
set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14)
# Speech Recognition with wav2vec2.0
This example demonstrates how you can use torchaudio's I/O features and models to run speech recognition in a C++ application.
**NOTE**
This example uses the `"sox_io"` backend for loading audio, which does not work on Windows. To make it work on Windows, you need to replace the part that loads the audio and converts it to a Tensor object.
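One possible workaround (a minimal sketch, not part of this example's scripts): skip `loader.zip` entirely and build the waveform Tensor yourself, for example with the `soundfile` package in Python or any audio library on the C++ side, then feed it straight to the encoder and decoder. The file and directory names below are placeholders.

```python
import soundfile  # assumption: any library that yields float32 PCM works here
import torch
import torchaudio

# Build the waveform Tensor without the sox_io-based loader.zip.
waveform, sample_rate = soundfile.read("input.wav", dtype="float32")
waveform = torch.from_numpy(waveform).reshape(1, -1)  # (channel, time) for a mono file
if sample_rate != 16000:
    waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.)

# The rest of the pipeline is unchanged.
encoder = torch.jit.load("pipeline-fairseq/encoder.zip")
decoder = torch.jit.load("pipeline-fairseq/decoder.zip")
print(decoder(encoder(waveform)))
```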
## 1. Create a transcription pipeline TorchScript file
We will create TorchScript object files that perform the following steps:
1. Load audio from a file.
1. Pass audio to encoder which produces the sequence of probability distribution on labels.
1. Pass the encoder output to decoder which generates transcripts.
To build the encoder, we take pre-trained weights published by `fairseq` or Hugging Face Transformers and convert them to `torchaudio`'s format, which supports TorchScript.
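To make steps 2 and 3 concrete, here is a toy illustration of what the decoder does with the encoder output. This is a greedy CTC collapse with made-up labels and shapes, not the beam-search decoder used below; it is only meant to show the data flow.

```python
import torch

# Toy emission: (batch=1, frames=6, labels=5); index 0 is the CTC blank, "|" is the word separator.
labels = ["-", "|", "A", "B", "C"]
emission = torch.randn(1, 6, len(labels)).log_softmax(dim=-1)

# Greedy CTC decoding: best label per frame, collapse repeats, drop blanks.
best = emission[0].argmax(dim=-1).tolist()
collapsed = [labels[i] for i, prev in zip(best, [None] + best[:-1]) if i != prev]
transcript = "".join(c for c in collapsed if c != "-").replace("|", " ")
print(transcript)
```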
### 1.1. From `fairseq`
You can download pre-trained weights from the [`fairseq` repository](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec). Here, we will use the `Base / 960h` model. You also need to download [the letter dictionary file](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#evaluating-a-ctc-model).
For the decoder part, we use [simple_ctc](https://github.com/mthrok/ctcdecode), which also supports TorchScript.
```bash
mkdir -p pipeline-fairseq
python build_pipeline_from_fairseq.py \
--model-file "wav2vec_small_960.pt" \
--dict-dir <DIRECTORY_WHERE_dict.ltr.txt_IS_FOUND> \
--output-path "./pipeline-fairseq/"
```
The above command should create the following TorchScript object files in the output directory.
```
decoder.zip encoder.zip loader.zip
```
* `loader.zip` loads an audio file and generates a waveform Tensor.
* `encoder.zip` receives the waveform Tensor and generates a sequence of probability distributions over the labels.
* `decoder.zip` receives the probability distributions over the labels and generates a transcript.
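These are ordinary TorchScript archives, so before wiring them into C++ you can load them back into Python to inspect and spot-check them. The snippet below is a sketch; the audio path is a placeholder.

```python
import torch

loader = torch.jit.load("pipeline-fairseq/loader.zip")
encoder = torch.jit.load("pipeline-fairseq/encoder.zip")
decoder = torch.jit.load("pipeline-fairseq/decoder.zip")

print(loader.code)  # TorchScript source of the loader's forward()
print(decoder(encoder(loader("some_audio.flac"))))  # end-to-end spot check
```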
### 1.2. From Hugging Face Transformers
[Hugging Face Transformers](https://huggingface.co/transformers/index.html) and the [Hugging Face Model Hub](https://huggingface.co/models) provide `wav2vec2.0` models fine-tuned on a variety of datasets and languages.
We can also import a model published on the Hugging Face Hub and run it in our C++ application.
In the following example, we will try the German model ([facebook/wav2vec2-large-xlsr-53-german](https://huggingface.co/facebook/wav2vec2-large-xlsr-53-german/tree/main)) on the [VoxForge German dataset](http://www.voxforge.org/de/downloads).
```bash
mkdir -p pipeline-hf
python build_pipeline_from_huggingface_transformers.py \
--model facebook/wav2vec2-large-xlsr-53-german \
--output-path ./pipeline-hf/
```
The resulting TorchScript object files should be the same as in the `fairseq` example.
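One difference worth noting: in the `fairseq` path the labels are hard-coded in the build script, while here the label set is read from the Hugging Face tokenizer. If you want to inspect it, the snippet below uses the same `transformers` calls as `build_pipeline_from_huggingface_transformers.py` (a sketch; downloading the model requires network access).

```python
from transformers import Wav2Vec2Processor

# Same label extraction as _get_model() in the build script.
tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-german").tokenizer
labels = [k for k, v in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])]
print(labels)
```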
## 2. Build the application
Please refer to [the top level README.md](../README.md)
## 3. Run the application
Now we run the C++ application [`transcribe`](./transcribe.cpp) with the TorchScript objects we created in Step 1.1 and an input audio file.
```bash
../build/speech_recognition/transcribe ./pipeline-fairseq ../data/input.wav
```
This will output something like the following.
```
Loading module from: ./pipeline/loader.zip
Loading module from: ./pipeline/encoder.zip
Loading module from: ./pipeline/decoder.zip
Loading the audio
Running inference
Generating the transcription
I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT
Done.
```
## 4. Evaluate the pipeline on Librispeech dataset
Let's evaluate the word error rate (WER) of this application using the [Librispeech dataset](https://www.openslr.org/12).
### 4.1. Create a list of audio paths
For the sake of simplifying our C++ code, we first parse the Librispeech dataset to get a list of audio paths.
```bash
python parse_librispeech.py <PATH_TO_YOUR_DATASET>/LibriSpeech/test-clean ./flist.txt
```
The list should look like the following:
```bash
head flist.txt
1089-134691-0000 /LibriSpeech/test-clean/1089/134691/1089-134691-0000.flac HE COULD WAIT NO LONGER
```
### 4.2. Run the transcription
[`transcribe_list`](./transcribe_list.cpp) processes the input file list, feeds the audio paths one by one to the pipeline, and generates a reference file and a hypothesis file.
```bash
../build/speech_recognition/transcribe_list ./pipeline-fairseq ./flist.txt <OUTPUT_DIR>
```
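For reference, the loop inside `transcribe_list` is equivalent to the following Python sketch; it also shows the `trn` format that `sclite` expects (transcription followed by the utterance ID in parentheses). Paths and file names are illustrative and correspond to the previous steps.

```python
import torch

loader = torch.jit.load("pipeline-fairseq/loader.zip")
encoder = torch.jit.load("pipeline-fairseq/encoder.zip")
decoder = torch.jit.load("pipeline-fairseq/decoder.zip")

with open("flist.txt") as flist, \
        open("ref.trn", "w") as ref, open("hyp.trn", "w") as hyp:
    for line in flist:
        id_, path, reference = line.rstrip("\n").split("\t")
        hypothesis = decoder(encoder(loader(path)))
        hyp.write(f"{hypothesis} ({id_})\n")  # "trn" format: text followed by (utterance id)
        ref.write(f"{reference} ({id_})\n")
```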
### 4.3. Score WER
You need `sclite` for this step. You can download the code from [SCTK repository](https://github.com/usnistgov/SCTK).
```bash
# in the output directory
sclite -r ref.trn -h hyp.trn -i wsj -o pralign -o sum
```
The WER can be found in the resulting `hyp.trn.sys` file. Check the row that starts with `Sum/Avg`; the first column of the third block is `100 - WER`.
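If `sclite` is not at hand, a rough WER sanity check can be done with a small word-level edit-distance script like the sketch below. Run it in the output directory, like the `sclite` command above; it is only an approximation of `sclite`'s alignment report.

```python
# Rough WER: word-level Levenshtein distance between reference and hypothesis.
def word_errors(ref: str, hyp: str) -> int:
    r, h = ref.split(), hyp.split()
    dist = list(range(len(h) + 1))
    for i, rw in enumerate(r, 1):
        prev, dist[0] = dist[0], i
        for j, hw in enumerate(h, 1):
            cur = min(dist[j] + 1, dist[j - 1] + 1, prev + (rw != hw))
            prev, dist[j] = dist[j], cur
    return dist[len(h)]

errors = total = 0
with open("ref.trn") as refs, open("hyp.trn") as hyps:
    for ref_line, hyp_line in zip(refs, hyps):
        ref = ref_line.rsplit("(", 1)[0].strip()  # drop the trailing "(utterance id)"
        hyp = hyp_line.rsplit("(", 1)[0].strip()
        errors += word_errors(ref, hyp)
        total += len(ref.split())
print(f"WER: {100 * errors / total:.1f}%")
```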
In our test, we got the following results.
| model | Fine Tune | test-clean | test-other |
|:-----------------------------------------:|----------:|:----------:|:----------:|
| Base<br/>`wav2vec_small_960` | 960h | 3.1 | 7.7 |
| Large<br/>`wav2vec_big_960` | 960h | 2.6 | 5.9 |
| Large (LV-60)<br/>`wav2vec2_vox_960h_new` | 960h | 2.9 | 6.2 |
| Large (LV-60) + Self Training<br/>`wav2vec_vox_960h_pl` | 960h | 1.9 | 4.5 |
You can also check the `hyp.trn.pra` file to see what errors were made.
```
id: (3528-168669-0005)
Scores: (#C #S #D #I) 7 1 0 0
REF: there is a stone to be RAISED heavy
HYP: there is a stone to be RACED heavy
Eval: S
```
## 5. Evaluate the pipeline on VoxForge dataset
Now we use the pipeline we created in step 1.2, this time with the German-language dataset from VoxForge.
### 5.1. Create a list of audio paths
Download an archive from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/, and extract it to your local file system, then run the following to generate the file list.
```bash
python parse_voxforge.py <PATH_TO_YOUR_DATASET> > ./flist-de.txt
```
The list should look like
```bash
head flist-de.txt
de5-001 /datasets/voxforge/de/guenter-20140214-afn/wav/de5-001.wav ES SOLL ETWA FÜNFZIGTAUSEND VERSCHIEDENE SORTEN GEBEN
```
### 5.2. Run the application and score WER
This process is the same as in the Librispeech example; we just use the pipeline with the German model and the file list of the German dataset. Refer to the corresponding section in the Librispeech evaluation.
```bash
../build/speech_recognition/transcribe_list ./pipeline-hf ./flist-de.txt <OUTPUT_DIR>
```
Then score the WER as before:
```bash
# in the output directory
sclite -r ref.trn -h hyp.trn -i wsj -o pralign -o sum
```
You can find the details of the evaluation result in the PRA file (`hyp.trn.pra`).
```
id: (guenter-20140214-afn/mfc/de5-012)
Scores: (#C #S #D #I) 4 1 1 0
REF: die ausgaben kÖnnen gigantisch STEIGE N
HYP: die ausgaben kÖnnen gigantisch ****** STEIGEN
Eval: D S
```
#!/usr/bin/env python3
"""Build Speech Recognition pipeline based on fairseq's wav2vec2.0 and dump it to TorchScript file.
To use this script, you need `fairseq`.
"""
import os
import argparse
import logging
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
import fairseq
import simple_ctc
_LG = logging.getLogger(__name__)
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument(
'--model-file',
required=True,
help='Path to the input pretrained weight file.'
)
parser.add_argument(
'--dict-dir',
help=(
'Path to the directory in which `dict.ltr.txt` file is found. '
'Required only when the model is finetuned.'
)
)
parser.add_argument(
'--output-path',
help='Path to the directory, where the TorchScript-ed pipelines are saved.',
)
parser.add_argument(
'--test-file',
help='Path to a test audio file.',
)
parser.add_argument(
'--debug',
action='store_true',
help=(
'When enabled, individual components are separately tested '
'for the numerical compatibility and TorchScript compatibility.'
)
)
parser.add_argument(
'--quantize',
action='store_true',
help='Apply quantization to model.'
)
parser.add_argument(
'--optimize-for-mobile',
action='store_true',
help='Apply optimization for mobile.'
)
return parser.parse_args()
class Loader(torch.nn.Module):
def forward(self, audio_path: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.)
return waveform
class Encoder(torch.nn.Module):
def __init__(self, encoder: torch.nn.Module):
super().__init__()
self.encoder = encoder
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
result, _ = self.encoder(waveform)
return result
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
def _get_decoder():
labels = [
"<s>",
"<pad>",
"</s>",
"<unk>",
"|",
"E",
"T",
"A",
"O",
"N",
"I",
"H",
"S",
"R",
"D",
"L",
"U",
"M",
"W",
"C",
"F",
"G",
"Y",
"P",
"B",
"V",
"K",
"'",
"X",
"J",
"Q",
"Z",
]
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
def _load_fairseq_model(input_file, data_dir=None):
overrides = {}
if data_dir:
overrides['data'] = data_dir
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[input_file], arg_overrides=overrides
)
model = model[0]
return model
def _get_model(model_file, dict_dir):
original = _load_fairseq_model(model_file, dict_dir)
model = import_fairseq_model(original.w2v_encoder)
return model
def _main():
args = _parse_args()
_init_logging(args.debug)
loader = Loader()
model = _get_model(args.model_file, args.dict_dir).eval()
encoder = Encoder(model)
decoder = _get_decoder()
_LG.info(encoder)
if args.quantize:
_LG.info('Quantizing the model')
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = torch.quantization.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
# test
if args.test_file:
_LG.info('Testing with %s', args.test_file)
waveform = loader(args.test_file)
emission = encoder(waveform)
transcript = decoder(emission)
_LG.info(transcript)
torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip'))
scripted = torch.jit.script(encoder)
if args.optimize_for_mobile:
scripted = optimize_for_mobile(scripted)
scripted.save(os.path.join(args.output_path, 'encoder.zip'))
def _init_logging(debug=False):
level = logging.DEBUG if debug else logging.INFO
format_ = (
'%(message)s' if not debug else
'%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s'
)
logging.basicConfig(level=level, format=format_)
if __name__ == '__main__':
_main()
#!/usr/bin/env python3
"""Build Speech Recognition pipeline based on Hugging Face Transformers' wav2vec2.0 and dump it to TorchScript file.
To use this script, you need `transformers`.
"""
import argparse
import logging
import os
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
import simple_ctc
_LG = logging.getLogger(__name__)
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument(
'--model',
required=True,
help='Model ID on the Hugging Face Model Hub (or path to the pretrained weights).'
)
parser.add_argument(
'--output-path',
help='Path to the directory, where the TorchScript-ed pipelines are saved.',
)
parser.add_argument(
'--test-file',
help='Path to a test audio file.',
)
parser.add_argument(
'--quantize',
action='store_true',
help='Quantize the model.',
)
parser.add_argument(
'--debug',
action='store_true',
help=(
'When enabled, individual components are separately tested '
'for the numerical compatibility and TorchScript compatibility.'
)
)
return parser.parse_args()
class Loader(torch.nn.Module):
def forward(self, audio_path: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.)
return waveform
class Encoder(torch.nn.Module):
def __init__(self, encoder: torch.nn.Module):
super().__init__()
self.encoder = encoder
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
length = torch.tensor([waveform.shape[1]])
result, length = self.encoder(waveform, length)
return result
class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder
def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
def _get_model(model_id):
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
tokenizer = Wav2Vec2Processor.from_pretrained(model_id).tokenizer
labels = [k for k, v in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])]
original = Wav2Vec2ForCTC.from_pretrained(model_id)
model = import_huggingface_model(original)
return model.eval(), labels
def _get_decoder(labels):
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
def _main():
args = _parse_args()
_init_logging(args.debug)
_LG.info('Loading model: %s', args.model)
model, labels = _get_model(args.model)
_LG.info('Labels: %s', labels)
_LG.info('Building pipeline')
loader = Loader()
encoder = Encoder(model)
decoder = _get_decoder(labels)
_LG.info(encoder)
if args.quantize:
_LG.info('Quantizing the model')
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = torch.quantization.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
# test
if args.test_file:
_LG.info('Testing with %s', args.test_file)
waveform = loader(args.test_file)
emission = encoder(waveform)
transcript = decoder(emission)
_LG.info(transcript)
torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
torch.jit.script(encoder).save(os.path.join(args.output_path, 'encoder.zip'))
torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip'))
def _init_logging(debug=False):
level = logging.DEBUG if debug else logging.INFO
format_ = (
'%(message)s' if not debug else
'%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s'
)
logging.basicConfig(level=level, format=format_)
if __name__ == '__main__':
_main()
#!/usr/bin/env python3
"""Parse a directory contains Librispeech dataset.
Recursively search for "*.trans.txt" file in the given directory and print out
`<ID>\\t<AUDIO_PATH>\\t<TRANSCRIPTION>`
example: python parse_librispeech.py LibriSpeech/test-clean
1089-134691-0000\t/LibriSpeech/test-clean/1089/134691/1089-134691-0000.flac\tHE COULD WAIT NO LONGER
...
Dataset can be obtained from https://www.openslr.org/12
"""
import argparse
from pathlib import Path
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'input_dir',
type=Path,
help='Directory where `*.trans.txt` files are searched.'
)
return parser.parse_args()
def _parse_transcript(path):
with open(path) as trans_fileobj:
for line in trans_fileobj:
line = line.strip()
if line:
yield line.split(' ', maxsplit=1)
def _parse_directory(root_dir: Path):
for trans_file in root_dir.glob('**/*.trans.txt'):
trans_dir = trans_file.parent
for id_, transcription in _parse_transcript(trans_file):
audio_path = trans_dir / f'{id_}.flac'
yield id_, audio_path, transcription
def _main():
args = _parse_args()
for id_, path, transcription in _parse_directory(args.input_dir):
print(f'{id_}\t{path}\t{transcription}')
if __name__ == '__main__':
_main()
#!/usr/bin/env python
"""Parse a directory contains VoxForge dataset.
Recursively search for "PROMPTS" file in the given directory and print out
`<ID>\\t<AUDIO_PATH>\\t<TRANSCRIPTION>`
example: python parse_voxforge.py voxforge/de/Helge-20150608-aku
de5-001\t/datasets/voxforge/de/guenter-20140214-afn/wav/de5-001.wav\tES SOLL ETWA FÜNFZIGTAUSEND VERSCHIEDENE SORTEN GEBEN
...
Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/
"""
import os
import argparse
from pathlib import Path
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'input_dir',
type=Path,
help='Directory where `PROMPTS` files are searched.'
)
return parser.parse_args()
def _parse_prompts(path):
base_dir = path.parent.parent
with open(path) as trans_fileobj:
for line in trans_fileobj:
line = line.strip()
if not line:
continue
id_, transcript = line.split(' ', maxsplit=1)
if not transcript:
continue
transcript = transcript.upper()
filename = id_.split('/')[-1]
audio_path = base_dir / 'wav' / f'{filename}.wav'
if os.path.exists(audio_path):
yield id_, audio_path, transcript
def _parse_directory(root_dir: Path):
for prompt_file in root_dir.glob('**/PROMPTS'):
try:
yield from _parse_prompts(prompt_file)
except UnicodeDecodeError:
pass
def _main():
args = _parse_args()
for id_, path, transcription in _parse_directory(args.input_dir):
print(f'{id_}\t{path}\t{transcription}')
if __name__ == '__main__':
_main()
#include <torch/script.h>
int main(int argc, char* argv[]) {
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " <JIT_OBJECT_DIR> <INPUT_AUDIO_FILE>" << std::endl;
return -1;
}
torch::jit::script::Module loader, encoder, decoder;
std::cout << "Loading module from: " << argv[1] << std::endl;
try {
loader = torch::jit::load(std::string(argv[1]) + "/loader.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
encoder = torch::jit::load(std::string(argv[1]) + "/encoder.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
decoder = torch::jit::load(std::string(argv[1]) + "/decoder.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
std::cout << "Loading the audio" << std::endl;
auto waveform = loader.forward({c10::IValue(argv[2])});
std::cout << "Running inference" << std::endl;
auto emission = encoder.forward({waveform});
std::cout << "Generating the transcription" << std::endl;
auto result = decoder.forward({emission});
std::cout << result.toString()->string() << std::endl;
std::cout << "Done." << std::endl;
}
#include <chrono>
#include <torch/script.h>
int main(int argc, char* argv[]) {
if (argc != 4) {
std::cerr << "Usage: " << argv[0] << "<JIT_OBJECT_DIR> <FILE_LIST> <OUTPUT_DIR>\n" << std::endl;
std::cerr << "<FILE_LIST> is `<ID>\t<PATH>\t<TRANSCRIPTION>`" << std::endl;
return -1;
}
torch::jit::script::Module loader, encoder, decoder;
std::cout << "Loading module from: " << argv[1] << std::endl;
try {
loader = torch::jit::load(std::string(argv[1]) + "/loader.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
encoder = torch::jit::load(std::string(argv[1]) + "/encoder.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
decoder = torch::jit::load(std::string(argv[1]) + "/decoder.zip");
} catch (const c10::Error &error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
std::ifstream input_file(argv[2]);
std::string output_dir(argv[3]);
std::ofstream output_ref(output_dir + "/ref.trn");
std::ofstream output_hyp(output_dir + "/hyp.trn");
std::string line;
std::chrono::milliseconds t_encode(0);
std::chrono::milliseconds t_decode(0);
while(std::getline(input_file, line)) {
std::istringstream iline(line);
std::string id;
std::string path;
std::string reference;
std::getline(iline, id, '\t');
std::getline(iline, path, '\t');
std::getline(iline, reference, '\t');
auto waveform = loader.forward({c10::IValue(path)});
std::chrono::steady_clock::time_point t0 = std::chrono::steady_clock::now();
auto emission = encoder.forward({waveform});
std::chrono::steady_clock::time_point t1 = std::chrono::steady_clock::now();
auto result = decoder.forward({emission});
std::chrono::steady_clock::time_point t2 = std::chrono::steady_clock::now();
t_encode += std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0);
t_decode += std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
auto hypothesis = result.toString()->string();
output_hyp << hypothesis << " (" << id << ")" << std::endl;
output_ref << reference << " (" << id << ")" << std::endl;
std::cout << id << '\t' << hypothesis << std::endl;
}
std::cout << "Time (encode): " << t_encode.count() << " [ms]" << std::endl;
std::cout << "Time (decode): " << t_decode.count() << " [ms]" << std::endl;
}