Commit 896ade04 authored by Caroline Chen, committed by Facebook GitHub Bot

Allow token list as CTC decoder input (#2112)

Summary:
Additionally accept a list of tokens as CTC decoder input. This makes it possible to pass something like `bundles.get_labels()` directly into the decoder factory function instead of requiring a separate tokens file.
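
For illustration, a minimal sketch of the call pattern this enables (the bundle name and file paths below are placeholders, not part of this change):

```python
import torchaudio
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

# Token labels can now be passed directly as a list of strings,
# e.g. the labels exposed by an ASR pipeline bundle.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M  # example bundle; any bundle with get_labels() works
decoder = kenlm_lexicon_decoder(
    lexicon="lexicon.txt",             # placeholder path
    tokens=list(bundle.get_labels()),  # token list instead of a tokens file
    kenlm="kenlm.arpa",                # placeholder path
)
```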

Pull Request resolved: https://github.com/pytorch/audio/pull/2112

Reviewed By: hwangjeff, nateanl, mthrok

Differential Revision: D33352909

Pulled By: carolineechen

fbshipit-source-id: 6d22072e34f6cd7c6f931ce4eaf294ae4cf0c5cc
parent bb528d7e
 import torch
+from parameterized import parameterized
 from torchaudio_unittest.common_utils import (
     TempDirMixin,
     TorchaudioTestCase,
@@ -9,21 +10,24 @@ from torchaudio_unittest.common_utils import (
 @skipIfNoCtcDecoder
 class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self):
+    def _get_decoder(self, tokens=None):
         from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

         lexicon_file = get_asset_path("decoder/lexicon.txt")
-        tokens_file = get_asset_path("decoder/tokens.txt")
         kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        if tokens is None:
+            tokens = get_asset_path("decoder/tokens.txt")

         return kenlm_lexicon_decoder(
             lexicon=lexicon_file,
-            tokens=tokens_file,
+            tokens=tokens,
             kenlm=kenlm_file,
         )

-    def test_construct_decoder(self):
-        self._get_decoder()
+    @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
+    def test_construct_decoder(self, tokens):
+        self._get_decoder(tokens)

     def test_shape(self):
         B, T, N = 4, 15, 10
@@ -36,9 +40,10 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         self.assertEqual(len(results), B)

-    def test_index_to_tokens(self):
+    @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
+    def test_index_to_tokens(self, tokens):
         # decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
-        decoder = self._get_decoder()
+        decoder = self._get_decoder(tokens)
         idxs = torch.LongTensor((1, 2, 1, 3, 5))
         tokens = decoder.idxs_to_tokens(idxs)
......
@@ -225,6 +225,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
   py::class_<Dictionary>(m, "_Dictionary")
       .def(py::init<>())
+      .def(py::init<const std::vector<std::string>&>(), "tkns"_a)
       .def(py::init<const std::string&>(), "filename"_a)
       .def("entry_size", &Dictionary::entrySize)
       .def("index_size", &Dictionary::indexSize)
......
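
For context, a minimal sketch of how the added binding might be exercised from Python; the `_Dictionary` class and module path come from the diff above, while the token list and file name are placeholders borrowed from the test fixtures (this is an internal API, shown only to illustrate the two constructors):

```python
from torchaudio._torchaudio_decoder import _Dictionary

# New constructor: build the dictionary directly from a list of tokens.
dict_from_list = _Dictionary(["-", "|", "f", "o", "b", "a", "r"])

# Existing constructor: build the dictionary from a tokens file (placeholder path).
dict_from_file = _Dictionary("decoder/tokens.txt")
```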
@@ -26,6 +26,15 @@ Dictionary::Dictionary(const std::string& filename) {
   createFromStream(stream);
 }

+Dictionary::Dictionary(const std::vector<std::string>& tkns) {
+  for (const auto& tkn : tkns) {
+    addEntry(tkn);
+  }
+  if (!isContiguous()) {
+    throw std::runtime_error("Invalid dictionary format - not contiguous");
+  }
+}
+
 void Dictionary::createFromStream(std::istream& stream) {
   if (!stream) {
     throw std::runtime_error("Unable to open dictionary input stream.");
......
@@ -26,6 +26,8 @@ class Dictionary {
   explicit Dictionary(const std::string& filename);
+  explicit Dictionary(const std::vector<std::string>& tkns);

   size_t entrySize() const;
   size_t indexSize() const;
......
 import itertools as it
 from collections import namedtuple
 from typing import Dict
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch

 from torchaudio._torchaudio_decoder import (
@@ -157,7 +157,7 @@ class KenLMLexiconDecoder:
 def kenlm_lexicon_decoder(
     lexicon: str,
-    tokens: str,
+    tokens: Union[str, List[str]],
     kenlm: str,
     nbest: int = 1,
     beam_size: int = 50,
@@ -177,7 +177,7 @@ def kenlm_lexicon_decoder(
     Args:
         lexicon (str): lexicon file containing the possible words
-        tokens (str): file containing valid tokens
+        tokens (str or List[str]): file or list containing valid tokens
         kenlm (str): file containing language model
         nbest (int, optional): number of best decodings to return (Default: 1)
         beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
......
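
Taken together with the docstring update, a hedged sketch of the two now-interchangeable ways to supply `tokens` (asset paths mirror the test fixtures and are placeholders):

```python
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

# Existing usage: tokens supplied as a path to a tokens file.
decoder_from_file = kenlm_lexicon_decoder(
    lexicon="decoder/lexicon.txt",
    tokens="decoder/tokens.txt",
    kenlm="decoder/kenlm.arpa",
)

# New usage: the same tokens supplied directly as a list of strings.
decoder_from_list = kenlm_lexicon_decoder(
    lexicon="decoder/lexicon.txt",
    tokens=["-", "|", "f", "o", "b", "a", "r"],
    kenlm="decoder/kenlm.arpa",
)
```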