Unverified Commit af652ca6 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Improve RNNT Loss docstrings (#1620)

parent d74d0604
...@@ -24,12 +24,13 @@ def rnnt_loss( ...@@ -24,12 +24,13 @@ def rnnt_loss(
dependencies. dependencies.
Args: Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``) blank (int, optional): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``) clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output: reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
...@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module): ...@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
dependencies. dependencies.
Args: Args:
blank (int, opt): blank label (Default: ``-1``) blank (int, optional): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``) clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output: reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
""" """
...@@ -95,7 +96,8 @@ class RNNTLoss(torch.nn.Module): ...@@ -95,7 +96,8 @@ class RNNTLoss(torch.nn.Module):
): ):
""" """
Args: Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment