Commit a6b34a5d authored by jiyuntu-eero's avatar jiyuntu-eero Committed by Facebook GitHub Bot
Browse files

Fix initialization of `get_trellis`. (#3172)

Summary:
Fix https://github.com/pytorch/audio/issues/3166. In `get_trellis` method, the index of blank symbol is regarded as 0 by default. It should be changed to `blank_id`.

Pull Request resolved: https://github.com/pytorch/audio/pull/3172

Reviewed By: mthrok

Differential Revision: D44090889

Pulled By: nateanl

fbshipit-source-id: d119f4ded895d31aeefd59f8d975224870100264
parent 014d7140
......@@ -154,7 +154,7 @@ def get_trellis(emission, tokens, blank_id=0):
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment