Unverified commit bac32ec1, authored by Caroline Chen, committed by GitHub

Add reduction parameter for RNNT loss (#1590)

parent 2376e9c9
@@ -31,6 +31,7 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
         blank=data["blank"],
         fused_log_softmax=data.get("fused_log_softmax", True),
         reuse_logits_for_grads=reuse_logits_for_grads,
+        reduction="none",
     )(
         logits=data["logits"],
         logit_lengths=data["logit_lengths"],
...
@@ -16,6 +16,7 @@ def rnnt_loss(
     clamp: float = -1,
     fused_log_softmax: bool = True,
     reuse_logits_for_grads: bool = True,
+    reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
     [:footcite:`graves2012sequence`].
@@ -31,14 +32,18 @@ def rnnt_loss(
         target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

     Returns:
         Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
         otherwise scalar.
     """
+    if reduction not in ['none', 'mean', 'sum']:
+        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
     if not fused_log_softmax:
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
         reuse_logits_for_grads = (
@@ -58,6 +63,11 @@ def rnnt_loss(
         fused_log_softmax=fused_log_softmax,
         reuse_logits_for_grads=reuse_logits_for_grads,)

+    if reduction == 'mean':
+        return costs.mean()
+    elif reduction == 'sum':
+        return costs.sum()
+
     return costs
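For illustration, here is a minimal usage sketch of the new ``reduction`` argument in the functional form. The file paths are elided in this diff, so the import path below is an assumption (the loss lived in torchaudio's prototype area around the time of this commit); argument names follow the docstring above, and logits have shape (batch, max_T, max_U + 1, num_classes).

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss  # assumed import path

batch, max_T, max_U, num_classes = 2, 10, 5, 20
logits = torch.randn(batch, max_T, max_U + 1, num_classes, requires_grad=True)
# Keep targets away from the blank label (defaults to -1, i.e. the last class).
targets = torch.randint(0, num_classes - 1, (batch, max_U), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_T, dtype=torch.int32)
target_lengths = torch.full((batch,), max_U, dtype=torch.int32)

# reduction="none" keeps the per-sequence costs, shape (batch,); this is
# what the updated test above relies on when comparing against references.
per_seq = rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    reduction="none",
)
# The default reduction="mean" (or "sum") collapses the costs to a scalar;
# an unrecognized value raises ValueError, per the check added above.
mean_loss = rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
assert per_seq.shape == (batch,) and mean_loss.dim() == 0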
@@ -74,6 +84,8 @@ class RNNTLoss(torch.nn.Module):
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """

     def __init__(
@@ -82,12 +94,14 @@ class RNNTLoss(torch.nn.Module):
         clamp: float = -1.,
         fused_log_softmax: bool = True,
         reuse_logits_for_grads: bool = True,
+        reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
         self.fused_log_softmax = fused_log_softmax
         self.reuse_logits_for_grads = reuse_logits_for_grads
+        self.reduction = reduction

     def forward(
         self,
@@ -116,4 +130,5 @@ class RNNTLoss(torch.nn.Module):
             self.clamp,
             self.fused_log_softmax,
             self.reuse_logits_for_grads,
+            self.reduction
         )
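The same option is available through the module interface, which stores ``reduction`` and forwards it to the functional form, as the diff above shows. Reusing the tensors from the earlier sketch, and under the same import-path assumption:

from torchaudio.prototype.rnnt_loss import RNNTLoss  # assumed import path

# reduction is fixed at construction time; "mean" is the default.
loss_fn = RNNTLoss(reduction="sum")
loss = loss_fn(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
loss.backward()  # scalar output, so backward() needs no explicit grad argument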