Commit dc35168d authored by pkufool
Browse files

Add delay penalty

parent 23b841cf
...@@ -202,6 +202,7 @@ def rnnt_loss_simple(
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
...@@ -228,6 +229,10 @@ def rnnt_loss_simple(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...@@ -257,6 +262,22 @@ def rnnt_loss_simple(
boundary=boundary,
modified=modified,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(
T0, device=px.device
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
...@@ -374,6 +395,7 @@ def rnnt_loss(
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
...@@ -395,6 +417,10 @@ def rnnt_loss(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...@@ -414,6 +440,20 @@ def rnnt_loss(
boundary=boundary,
modified=modified,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(
T0, device=px.device
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
...@@ -480,6 +520,7 @@ def _adjust_pruning_lower_bound(
)
return s_begin
# To get more insight of how we calculate pruning bounds, please read
# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
# (https://arxiv.org/pdf/2206.13236.pdf)
...@@ -818,6 +859,7 @@ def rnnt_loss_pruned(
termination_symbol: int,
boundary: Tensor = None,
modified: bool = False,
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses a pruned 'joiner' network output
...@@ -842,6 +884,10 @@ def rnnt_loss_pruned(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...@@ -861,6 +907,20 @@ def rnnt_loss_pruned(
boundary=boundary,
modified=modified,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(
T0, device=px.device
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
...@@ -1105,6 +1165,7 @@ def rnnt_loss_smoothed(
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
...@@ -1138,6 +1199,10 @@ def rnnt_loss_smoothed(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...@@ -1170,6 +1235,20 @@ def rnnt_loss_smoothed(
boundary=boundary,
modified=modified,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(
T0, device=px.device
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.