Commit 896ade04 authored by Caroline Chen, committed by Facebook GitHub Bot

Allow token list as CTC decoder input (#2112)

Summary:
Additionally accept a list of tokens as CTC decoder input. This makes it possible to pass something like `bundles.get_labels()` directly into the decoder factory function instead of requiring a separate tokens file.

Pull Request resolved: https://github.com/pytorch/audio/pull/2112

Reviewed By: hwangjeff, nateanl, mthrok

Differential Revision: D33352909

Pulled By: carolineechen

fbshipit-source-id: 6d22072e34f6cd7c6f931ce4eaf294ae4cf0c5cc
parent bb528d7e
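The change is easiest to see from the caller's side. Below is a minimal sketch of the new call pattern described in the summary; the lexicon and KenLM paths and the particular pipeline bundle are placeholders chosen for illustration, not part of this commit.

import torchaudio
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

# Any bundle exposing get_labels() works here; WAV2VEC2_ASR_BASE_960H is just an example.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

decoder = kenlm_lexicon_decoder(
    lexicon="lexicon.txt",              # placeholder path to a lexicon file
    tokens=list(bundle.get_labels()),   # token list passed directly; no tokens file needed
    kenlm="kenlm.arpa",                 # placeholder path to a KenLM language model
)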
import torch
+from parameterized import parameterized
from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
@@ -9,21 +10,24 @@ from torchaudio_unittest.common_utils import (
@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self):
+    def _get_decoder(self, tokens=None):
        from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

        lexicon_file = get_asset_path("decoder/lexicon.txt")
-        tokens_file = get_asset_path("decoder/tokens.txt")
        kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        if tokens is None:
+            tokens = get_asset_path("decoder/tokens.txt")

        return kenlm_lexicon_decoder(
            lexicon=lexicon_file,
-            tokens=tokens_file,
+            tokens=tokens,
            kenlm=kenlm_file,
        )

-    def test_construct_decoder(self):
-        self._get_decoder()
+    @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
+    def test_construct_decoder(self, tokens):
+        self._get_decoder(tokens)

    def test_shape(self):
        B, T, N = 4, 15, 10
@@ -36,9 +40,10 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
        self.assertEqual(len(results), B)

-    def test_index_to_tokens(self):
+    @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
+    def test_index_to_tokens(self, tokens):
        # decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
-        decoder = self._get_decoder()
+        decoder = self._get_decoder(tokens)
        idxs = torch.LongTensor((1, 2, 1, 3, 5))
        tokens = decoder.idxs_to_tokens(idxs)
...
@@ -225,6 +225,7 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
  py::class_<Dictionary>(m, "_Dictionary")
      .def(py::init<>())
+     .def(py::init<const std::vector<std::string>&>(), "tkns"_a)
      .def(py::init<const std::string&>(), "filename"_a)
      .def("entry_size", &Dictionary::entrySize)
      .def("index_size", &Dictionary::indexSize)
...
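The added binding above is what makes the list form reachable from Python: the internal `_Dictionary` can now be built straight from a token list rather than only from a file. The snippet below is a rough sketch of that internal path (it assumes torchaudio was built with the prototype decoder extension); user code should normally go through `kenlm_lexicon_decoder` instead.

from torchaudio._torchaudio_decoder import _Dictionary

tokens = ["-", "|", "f", "o", "b", "a", "r"]
token_dict = _Dictionary(tokens)          # new: construct from a list of token strings
# token_dict = _Dictionary("tokens.txt")  # existing: construct from a tokens file
print(token_dict.entry_size())            # number of entries added, 7 here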
@@ -26,6 +26,15 @@ Dictionary::Dictionary(const std::string& filename) {
  createFromStream(stream);
}

+Dictionary::Dictionary(const std::vector<std::string>& tkns) {
+  for (const auto& tkn : tkns) {
+    addEntry(tkn);
+  }
+  if (!isContiguous()) {
+    throw std::runtime_error("Invalid dictionary format - not contiguous");
+  }
+}
+
void Dictionary::createFromStream(std::istream& stream) {
  if (!stream) {
    throw std::runtime_error("Unable to open dictionary input stream.");
...
@@ -26,6 +26,8 @@ class Dictionary {
  explicit Dictionary(const std::string& filename);

+  explicit Dictionary(const std::vector<std::string>& tkns);
+
  size_t entrySize() const;
  size_t indexSize() const;
...
import itertools as it
from collections import namedtuple
from typing import Dict
-from typing import List, Optional
+from typing import List, Optional, Union

import torch
from torchaudio._torchaudio_decoder import (
@@ -157,7 +157,7 @@ class KenLMLexiconDecoder:

def kenlm_lexicon_decoder(
    lexicon: str,
-    tokens: str,
+    tokens: Union[str, List[str]],
    kenlm: str,
    nbest: int = 1,
    beam_size: int = 50,
@@ -177,7 +177,7 @@ def kenlm_lexicon_decoder(
    Args:
        lexicon (str): lexicon file containing the possible words
-        tokens (str): file containing valid tokens
+        tokens (str or List[str]): file or list containing valid tokens
        kenlm (str): file containing language model
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
...
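Putting the pieces together, here is a brief sketch mirroring the updated tests: the token list used in the test assets stands in for a tokens file, and decoding proceeds as before. The lexicon and KenLM paths are placeholders, and the emission shape and hypothesis fields (indexing the result as [batch][nbest] and reading `.tokens`) follow the existing decoder API as assumptions for illustration only.

import torch
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

tokens = ["-", "|", "f", "o", "b", "a", "r"]
decoder = kenlm_lexicon_decoder(
    lexicon="decoder/lexicon.txt",   # placeholder path to a lexicon file
    tokens=tokens,                   # list form, equivalent to passing a tokens file path
    kenlm="decoder/kenlm.arpa",      # placeholder path to a KenLM language model
)

emissions = torch.rand(1, 15, len(tokens))   # (batch, frames, num_tokens)
hypotheses = decoder(emissions)              # nbest hypotheses per batch element
best = hypotheses[0][0]
print(decoder.idxs_to_tokens(best.tokens))   # map token indices back to token strings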