Commit 0c1e3253 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix virtual function issue with CTC decoder (#3230)

Summary:
Currently, creating CTCDecoder object by passing a language model to
`lm` argument without assigning it to a variable elsewhere causes
`RuntimeError: Tried to call pure virtual function "LM::start"`.

According to discussions on PyBind11, (
https://github.com/pybind/pybind11/discussions/4013 and
https://github.com/pybind/pybind11/pull/2839
) this is due to Python object garbage-collected by the time
it's used by code implemented in C++. It attempts to call
methods defined in Python, which overrides the base pure virtual
function, but the object which provides this override gets
deleted by garbage collrector, as the original object is not
reference counted.

This commit fixes this by simply assiging the given `lm` object
as an attribute of CTCDecoder class.

Address https://github.com/pytorch/audio/issues/3218

Pull Request resolved: https://github.com/pytorch/audio/pull/3230

Reviewed By: hwangjeff

Differential Revision: D44642989

Pulled By: mthrok

fbshipit-source-id: a90af828c7c576bc0eb505164327365ebaadc471
parent 6270e609
...@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase): ...@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
expected_tokens = ["|", "f", "|", "o", "a"] expected_tokens = ["|", "f", "|", "o", "a"]
self.assertEqual(tokens, expected_tokens) self.assertEqual(tokens, expected_tokens)
def test_lm_lifecycle(self):
"""Passing lm without assiging it to a vaiable won't cause runtime error
https://github.com/pytorch/audio/issues/3218
"""
from torchaudio.models.decoder import ctc_decoder
from .ctc_decoder_utils import CustomZeroLM
decoder = ctc_decoder(
lexicon=get_asset_path("decoder/lexicon.txt"),
tokens=get_asset_path("decoder/tokens.txt"),
lm=CustomZeroLM(),
)
decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))
...@@ -269,6 +269,12 @@ class CTCDecoder: ...@@ -269,6 +269,12 @@ class CTCDecoder:
) )
else: else:
self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions) self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
# https://github.com/pytorch/audio/issues/3218
# If lm is passed like rvalue reference, the lm object gets garbage collected,
# and later call to the lm fails.
# This ensures that lm object is not deleted as long as the decoder is alive.
# https://github.com/pybind/pybind11/discussions/4013
self.lm = lm
def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs)) idxs = (g[0] for g in it.groupby(idxs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment