from torch import topk


class GreedyDecoder:
    """Callable that keeps only the best-scoring class at every position."""

    def __call__(self, outputs):
        """Select the top-scoring class index along the last dimension.

        Args:
            outputs (torch.Tensor): per-class scores, expected shape
                (input length, batch size, number of classes
                (including blank)).

        Returns:
            torch.Tensor: index of the largest entry along the last
                dimension, i.e. one class label per time step. The result
                has the shape of ``outputs`` with the last dimension removed.
        """
        # topk returns a (values, indices) named tuple; only the indices
        # are needed here.
        top1 = topk(outputs, k=1, dim=-1)
        labels = top1.indices
        # k=1 leaves a trailing dimension of size one; index it away.
        return labels[..., 0]