Unverified Commit 6a8ecd98 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Add return type in doc for RNNT loss (#1591)

parent 89807cf7
......@@ -34,6 +34,10 @@ def rnnt_loss(
runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
if not fused_log_softmax:
logits = torch.nn.functional.log_softmax(logits, dim=-1)
......@@ -98,6 +102,10 @@ class RNNTLoss(torch.nn.Module):
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
return rnnt_loss(
logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment