import torch


class Decoder(torch.nn.Module):
    """Greedy (best-path) decoder mapping per-step logits over labels to text.

    Args:
        labels: Sequence of label strings; ``labels[k]`` is the text emitted
            for label index ``k`` along the last axis of the logits.
        ignore: Label strings to drop from the transcript (e.g. blank/padding
            tokens). Defaults to ``("<s>", "<pad>", "")``.
    """

    def __init__(self, labels, ignore=None):
        super().__init__()
        self.labels = labels
        # NOTE(review): this skip list used to be the literal ["", ""], which
        # looks like "<s>"/"<pad>" with the angle-bracket tokens stripped out.
        # "<s>"/"<pad>" are assumed to be the blank/padding labels of the label
        # set in use -- confirm, or pass ``ignore`` explicitly. "" is kept so
        # empty labels are still skipped as before.
        if ignore is None:
            ignore = ("<s>", "<pad>", "")
        self.ignore = list(ignore)

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence of logits over labels, get the best-path string.

        Takes the arg-max label at each step, collapses consecutive repeats,
        drops labels listed in ``ignore`` and maps ``"|"`` to a space.

        Args:
            logits (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript ("" for an empty sequence).

        Raises:
            ValueError: If ``logits`` is not a 2-D tensor.
        """
        if logits.dim() != 2:
            raise ValueError(
                "Expected logits of shape [num_seq, num_label], got a "
                + str(logits.dim())
                + "-D tensor"
            )
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        if best_path.numel() == 0:
            return ""
        # Collapse runs *before* removing ignored labels, so a label repeated
        # on both sides of an ignored one is emitted twice.
        best_path = torch.unique_consecutive(best_path, dim=-1)
        hypothesis = ""
        for i in best_path:
            char = self.labels[int(i)]
            if char in self.ignore:
                continue
            if char == "|":
                # "|" is the word-boundary label; render it as a space.
                char = " "
            hypothesis += char
        return hypothesis