Unverified Commit af652ca6 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Improve RNNT Loss docstrings (#1620)

parent d74d0604
......@@ -24,12 +24,13 @@ def rnnt_loss(
dependencies.
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
......@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
dependencies.
Args:
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
......@@ -95,7 +96,8 @@ class RNNTLoss(torch.nn.Module):
):
"""
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment