Commit ed6256a2 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot

Remove invalid token blanking logic from RNN-T decoder (#2180)

Summary:
This PR removes logic in `RNNTBeamSearch` that blanks out joiner output values corresponding to special tokens, e.g. `<unk>` and `<eos>`, for the following reasons:
- Provided that the model is configured and trained properly, the blanking shouldn't be necessary: the model naturally learns to assign low probabilities to special tokens that never appear in the training targets. (A standalone sketch of the removed masking step follows this list.)
- For our pre-trained LibriSpeech pipeline, removing the logic doesn't affect evaluation WER on any of the dev/test splits.
- The existing logic doesn't generalize to arbitrary token vocabularies, since it hard-codes the assumption that the first four token indices are the special tokens.
- Internally, it seems to have been acknowledged that this logic was introduced to compensate for quirks in other parts of the modeling infra.
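
For reference, a minimal standalone sketch of the removed masking step. The shapes, `beam_width`, and random input tensor below are made up for illustration; in the decoder the tensor comes from the joiner:

```python
import torch

# Stand-in for the joiner output: [beam_width, 1, 1, num_tokens].
beam_width, num_tokens, temperature = 2, 10, 1.0
joined_out = torch.randn(beam_width, 1, 1, num_tokens)

joined_out = torch.nn.functional.log_softmax(joined_out / temperature, dim=3)
# The removed line: an in-place add on a slice view, pushing the log-probs
# of the first four token indices (assumed to be the special tokens) so low
# that beam search can never select them.
joined_out[:, :, :, :4].add_(-99999)

assert joined_out[..., :4].max().item() < -9000  # special tokens suppressed
print(joined_out[:, 0, 0])  # what the decoder returns: [beam_width, num_tokens]
```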

Pull Request resolved: https://github.com/pytorch/audio/pull/2180

Reviewed By: carolineechen, mthrok

Differential Revision: D33822683

Pulled By: hwangjeff

fbshipit-source-id: e7047e294f71c732c77ae0c20fec60412f26f05a
parent 4c3fa875
```diff
@@ -135,7 +135,6 @@ class RNNTBeamSearch(torch.nn.Module):
             torch.tensor([1] * len(hypos), device=device),
         )  # [beam_width, 1, 1, num_tokens]
         joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
-        joined_out[:, :, :, :4].add_(-99999)  # blank out invalid tokens
         return joined_out[:, 0, 0]

     def _gen_b_hypos(
```
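
After this change, beam search scores come straight from the joiner's log-softmax with no post hoc masking. A hedged end-to-end sketch, assuming the `torchaudio.models` factory and decoder API from around this release; the vocabulary size, blank id, and input shapes below are illustrative, not the LibriSpeech pipeline's values:

```python
import torch
from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base

# Illustrative only: an untrained Emformer RNN-T with a made-up vocabulary.
num_symbols = 4097                  # e.g. 4096 subword units + blank
model = emformer_rnnt_base(num_symbols=num_symbols)
decoder = RNNTBeamSearch(model, blank=num_symbols - 1)

features = torch.randn(1, 400, 80)  # [batch=1, frames, mel bins]
length = torch.tensor([400])
hypos = decoder(features, length, beam_width=5)
print(hypos[0][0])                  # token ids of the top hypothesis
```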