Commit b0ed23ef authored by pkufool

Add constrained rnnt

parent dc35168d
@@ -26,13 +26,13 @@ from .mutual_information import mutual_information_recursion
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
-    None. If boundary == None and modified == False, px[:,:,-1] will
+    None. If boundary == None and rnnt_type == "regular", px[:,:,-1] will
     be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
     to be -infinity.
     Args:
       px: a Tensor of shape [B][S][T+1] (this function is only
-          called if modified == False, see other docs for `modified`)
+          called if rnnt_type == "regular", see other docs for `rnnt_type`)
           px is modified in-place and returned.
       boundary: None, or a Tensor of shape [B][4] containing
           [s_begin, t_begin, s_end, t_end]; we need only t_end.
@@ -49,8 +49,8 @@ def get_rnnt_logprobs(
     am: Tensor,
     symbols: Tensor,
     termination_symbol: int,
+    rnnt_type: str = "regular",
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
 ) -> Tuple[Tensor, Tensor]:
     """
     Reduces RNN-T problem (the simple case, where joiner network is just
@@ -97,20 +97,32 @@ def get_rnnt_logprobs(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
     Returns:
         (px, py) (the names are quite arbitrary).
-           px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
+           px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+               [B][S][T] if rnnt_type is not regular.
           py: logprobs, of shape [B][S+1][T]
         in the recursion::
           p[b,0,0] = 0.0
-          if !modified:
+          if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-          if modified:
+          if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])
     .. where p[b][s][t] is the "joint score" of the pair of subsequences
@@ -121,21 +133,22 @@ def get_rnnt_logprobs(
     (s,t) by one in the t direction,
     i.e. of emitting the termination/next-frame symbol.
-    if !modified, px[:,:,T] equals -infinity, meaning on the
+    if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
     "one-past-the-last" frame we cannot emit any symbols.
     This is simply a way of incorporating
     the probability of the termination symbol on the last frame.
     """
-    assert lm.ndim == 3
-    assert am.ndim == 3
-    assert lm.shape[0] == am.shape[0]
-    assert lm.shape[2] == am.shape[2]
+    assert lm.ndim == 3, lm.ndim
+    assert am.ndim == 3, am.ndim
+    assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
+    assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.
@@ -162,7 +175,7 @@ def get_rnnt_logprobs(
         -1
     )  # [B][S][T]
-    if not modified:
+    if rnnt_type == "regular":
         px_am = torch.cat(
             (
                 px_am,
@@ -189,8 +202,10 @@ def get_rnnt_logprobs(
     py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
     py = py_am + py_lm - normalizers
-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]
     return (px, py)
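The `constrained` branch above is the whole trick: after the regular `px`/`py` are computed, the blank log-prob of the *next* symbol context is folded into `px`. A toy illustration of that one line (sizes and values are made up; `px`/`py` here are random stand-ins for the tensors this function actually builds):

```python
import torch

B, S, T = 1, 2, 3  # made-up sizes
px = torch.randn(B, S, T)      # log-prob of emitting symbol s+1 at frame t
py = torch.randn(B, S + 1, T)  # log-prob of emitting blank after s symbols

# Constrained rnnt: emitting symbol s+1 on frame t forces the blank
# transition from context s+1 on the same frame, so the two log-probs
# are summed before the recursion is run.
px = px + py[:, 1:, :]
```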
@@ -201,7 +216,7 @@ def rnnt_loss_simple(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
     return_grad: bool = False,
@@ -227,8 +242,19 @@ def rnnt_loss_simple(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
       delay_penalty: A constant value to penalize symbol delay, this may be
           needed when training with time masking, to avoid the time-masking
           encouraging the network to delay symbols.
@@ -260,12 +286,12 @@ def rnnt_loss_simple(
         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
     if delay_penalty > 0.0:
         B, S, T0 = px.shape
-        T = T0 if modified else T0 - 1
+        T = T0 if rnnt_type != "regular" else T0 - 1
         if boundary is None:
             offset = torch.tensor(
                 (T - 1) / 2, dtype=px.dtype, device=px.device,
@@ -289,9 +315,9 @@ def rnnt_loss_simple(
     elif reduction == "sum":
         loss = -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
     return (loss, scores_and_grads[1]) if return_grad else loss
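The middle of the `delay_penalty` branch is elided by the diff context above; the visible lines only show the `offset` construction and the final scaling. The sketch below illustrates the shape of the computation under the assumption that `penalty` is the per-frame offset `offset - arange(T0)`; it is an illustration, not the commit's exact code:

```python
import torch

B, S, T0 = 2, 3, 10  # made-up sizes; regular rnnt, so T = T0 - 1
delay_penalty = 0.01
px = torch.zeros(B, S, T0)
T = T0 - 1

# Without a boundary, every sequence spans all T frames, so the reference
# point is the middle frame (T - 1) / 2 (this line is visible in the hunk).
offset = torch.tensor((T - 1) / 2, dtype=px.dtype)

# Assumed middle of the hunk: frames before the reference point get a
# bonus and later frames a penalty, nudging symbols to be emitted earlier.
penalty = offset - torch.arange(T0, dtype=px.dtype).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
```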
@@ -300,7 +326,7 @@ def get_rnnt_logprobs_joint(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Reduces RNN-T problem to a compact, standard form that can then be given
     (with boundaries) to mutual_information_recursion().
@@ -321,21 +347,33 @@ def get_rnnt_logprobs_joint(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
     Returns:
         (px, py) (the names are quite arbitrary)::
-           px: logprobs, of shape [B][S][T+1]
+           px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+               [B][S][T] if rnnt_type is not regular.
           py: logprobs, of shape [B][S+1][T]
         in the recursion::
           p[b,0,0] = 0.0
-          if !modified:
+          if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-          if modified:
+          if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])
     .. where p[b][s][t] is the "joint score" of the pair of subsequences of
@@ -345,17 +383,18 @@ def get_rnnt_logprobs_joint(
     of extending the subsequences of length (s,t) by one in the t direction,
     i.e. of emitting the termination/next-frame symbol.
-    if !modified, px[:,:,T] equals -infinity, meaning on the
+    if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
     "one-past-the-last" frame we cannot emit any symbols.
     This is simply a way of incorporating
     the probability of the termination symbol on the last frame.
     """
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, S1, C) = logits.shape
     S = S1 - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
@@ -365,7 +404,7 @@ def get_rnnt_logprobs_joint(
     ).squeeze(-1)
     px = px.permute((0, 2, 1))
-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,
@@ -383,8 +422,10 @@ def get_rnnt_logprobs_joint(
     )  # [B][S+1][T]
     py -= normalizers
-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]
     return (px, py)
@@ -394,7 +435,7 @@ def rnnt_loss(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
 ) -> Tensor:
@@ -415,8 +456,19 @@ def rnnt_loss(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
       delay_penalty: A constant value to penalize symbol delay, this may be
           needed when training with time masking, to avoid the time-masking
           encouraging the network to delay symbols.
@@ -438,11 +490,12 @@ def rnnt_loss(
         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
     if delay_penalty > 0.0:
         B, S, T0 = px.shape
-        T = T0 if modified else T0 - 1
+        T = T0 if rnnt_type != "regular" else T0 - 1
         if boundary is None:
             offset = torch.tensor(
                 (T - 1) / 2, dtype=px.dtype, device=px.device,
@@ -454,6 +507,7 @@ def rnnt_loss(
         ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
     negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
     if reduction == "none":
         return -negated_loss
@@ -462,30 +516,30 @@ def rnnt_loss(
     elif reduction == "sum":
         return -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
 def _adjust_pruning_lower_bound(
     s_begin: torch.Tensor, s_range: int
 ) -> torch.Tensor:
-    """Adjust s_begin (pruning lower bound) to make it satisfied the following
-    constrains
+    """Adjust s_begin (pruning lower bounds) to make it satisfy the following
+    constraints
     - monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
     - start with symbol 0 at first frame.
-    - s_begin[i + 1] - s_begin[i] < s_range, whicn means that we can't skip
+    - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
       any symbols.
     To make it monotonic increasing, we can use the `monotonic_lower_bound` function
-    in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
+    in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
     traverse the array in reverse order and update the elements by
-    `min_value = min(a_begin[i], min_value)`, the initial `min_value` set to
+    `min_value = min(s_begin[i], min_value)`, where the initial `min_value` is set to
     `inf`.
     The method we used to realize the `s_begin[i + 1] - s_begin[i] < s_range`
-    constrain is a little tricky. We first transform `s_begin` with
+    constraint is a little tricky. We first transform `s_begin` with
     `s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
     then we make the transformed `s_begin` monotonic increasing, after that,
     we transform back `s_begin` with the same formula as the previous
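A pure-PyTorch sketch of the procedure this docstring describes (k2's `monotonic_lower_bound` itself is implemented in C++/CUDA; the reverse-cummin below is only an illustrative equivalent, and the input values are made up):

```python
import torch

def monotonic_lower_bound(s: torch.Tensor) -> torch.Tensor:
    # Traverse in reverse order keeping a running minimum (a suffix min),
    # which makes the result monotonically non-decreasing.
    return torch.flip(
        torch.cummin(torch.flip(s, dims=[-1]), dim=-1).values, dims=[-1]
    )

s_begin = torch.tensor([0, 2, 0, 1, 4])  # made-up lower bounds, T = 5
s_range = 2
t = torch.arange(s_begin.numel())

s_begin = monotonic_lower_bound(s_begin)  # -> [0, 0, 0, 1, 4]
s_begin = -(s_begin - (s_range - 1) * t)  # the tricky transform
s_begin = monotonic_lower_bound(s_begin)  # enforce the diff constraint
s_begin = -(s_begin - (s_range - 1) * t)  # the same formula inverts it
print(s_begin)  # tensor([0, 1, 2, 3, 4]); consecutive diffs all < s_range
```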
@@ -551,9 +605,9 @@ def get_rnnt_prune_ranges(
     Note:
       For the generated tensor ranges (assuming batch size is 1), ranges[:, 0]
-      is a monotonic increasing tensor from 0 to `len(symbols)` and it satisfies
-      `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
-      symbols.
+      is a monotonic increasing tensor from 0 to `len(symbols) - s_range` and
+      it satisfies `ranges[t+1, 0] - ranges[t, 0] < s_range`, which means we
+      won't skip any symbols.
     Args:
       px_grad:
@@ -568,21 +622,21 @@ def get_rnnt_prune_ranges(
       s_range:
         How many symbols to keep for each frame.
     Returns:
-      A tensor contains the kept symbols indexes for each frame, with shape
-      (B, T, s_range).
+      A tensor of shape (B, T, s_range) containing the indexes of the kept
+      symbols for each frame.
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
-    assert T1 in [T, T + 1]
+    assert T1 in [T, T + 1], T1
     S1 = S + 1
-    assert py_grad.shape == (B, S + 1, T)
-    assert boundary.shape == (B, 4)
-    assert S >= 1
-    assert T >= S
+    assert py_grad.shape == (B, S + 1, T), py_grad.shape
+    assert boundary.shape == (B, 4), boundary.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
     # s_range > S means we won't prune out any symbols. To make indexing with
-    # ranges runs normally, s_range should be equal to or less than ``S + 1``.
+    # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1
@@ -630,16 +684,17 @@ def get_rnnt_prune_ranges(
     mask = mask < boundary[:, 3].reshape(B, 1) - 1
     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
-    # handle the cases when `len(symbols) < s_range`
+    # handle the cases where `len(symbols) < s_range`
     s_begin_padding = torch.clamp(s_begin_padding, min=0)
     s_begin = torch.where(mask, s_begin, s_begin_padding)
-    # adjusting lower bound to make it satisfied some constrains, see docs in
-    # `_adjust_pruning_lower_bound` for more details of these constrains.
-    # T1 == T here means we are using the modified version of transducer,
-    # the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
-    # it only emits one symbol per frame.
+    # adjust the lower bound to make it satisfy some constraints, see docs in
+    # `_adjust_pruning_lower_bound` for more details of these constraints.
+    # T1 == T here means we are using a non-regular (i.e. modified rnnt or
+    # constrained rnnt) version of the transducer; the third constraint becomes
+    # `s_begin[i + 1] - s_begin[i] < 2`, because it only emits one symbol per
+    # frame.
     s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
     ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
@@ -652,8 +707,8 @@ def get_rnnt_prune_ranges(
 def do_rnnt_pruning(
     am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Prune the output of encoder(am) output and prediction network(lm)
-    output of RNNT.
+    """Prune the output of the encoder (am) and the prediction network (lm)
+    with the ranges generated by `get_rnnt_prune_ranges`.
     Args:
       am:
@@ -671,9 +726,9 @@ def do_rnnt_pruning(
     # am (B, T, C)
     # lm (B, S + 1, C)
     # ranges (B, T, s_range)
-    assert ranges.shape[0] == am.shape[0]
-    assert ranges.shape[0] == lm.shape[0]
-    assert am.shape[1] == ranges.shape[1]
+    assert ranges.shape[0] == am.shape[0], (ranges.shape[0], am.shape[0])
+    assert ranges.shape[0] == lm.shape[0], (ranges.shape[0], lm.shape[0])
+    assert am.shape[1] == ranges.shape[1], (am.shape[1], ranges.shape[1])
     (B, T, s_range) = ranges.shape
     (B, S1, C) = lm.shape
     S = S1 - 1
@@ -711,9 +766,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
              [ 8,  9,  5,  6,  7],
              [12, 13, 14, 10, 11]]])
     """
-    assert src.dim() == 3
+    assert src.dim() == 3, src.dim()
     (B, T, S) = src.shape
-    assert shifts.shape == (B, T)
+    assert shifts.shape == (B, T), shifts.shape
     index = (
         torch.arange(S, device=src.device)
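The body of `_roll_by_shifts` is truncated by the diff context. A self-contained gather-based equivalent that reproduces the docstring's example output (the first example row is inferred; this is an illustration, not necessarily the exact implementation):

```python
import torch

def roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor) -> torch.Tensor:
    # Roll each row src[b, t] to the right by shifts[b, t] positions.
    (B, T, S) = src.shape
    index = (torch.arange(S, device=src.device) - shifts.unsqueeze(-1)) % S
    return torch.gather(src, 2, index)

src = torch.arange(15).reshape(1, 3, 5)
print(roll_by_shifts(src, torch.tensor([[1, 2, 3]])))
# tensor([[[ 4,  0,  1,  2,  3],
#          [ 8,  9,  5,  6,  7],
#          [12, 13, 14, 10, 11]]])
```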
@@ -731,7 +786,7 @@ def get_rnnt_logprobs_pruned(
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
@@ -751,21 +806,53 @@ def get_rnnt_logprobs_pruned(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
     Returns:
-      Return the px (B, S, T) if modified else (B, S, T + 1) and
-      py (B, S + 1, T) needed by mutual_information_recursion.
+        (px, py) (the names are quite arbitrary)::
+           px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+               [B][S][T] if rnnt_type is not regular.
+           py: logprobs, of shape [B][S+1][T]
+        in the recursion::
+          p[b,0,0] = 0.0
+          if rnnt_type == "regular":
+            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                               p[b,s,t-1] + py[b,s,t-1])
+          if rnnt_type != "regular":
+            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                               p[b,s,t-1] + py[b,s,t-1])
+        .. where p[b][s][t] is the "joint score" of the pair of subsequences of
+        length s and t respectively. px[b][s][t] represents the probability of
+        extending the subsequences of length (s,t) by one in the s direction,
+        given the particular symbol, and py[b][s][t] represents the probability
+        of extending the subsequences of length (s,t) by one in the t direction,
+        i.e. of emitting the termination/next-frame symbol.
+        if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
+        "one-past-the-last" frame we cannot emit any symbols.
+        This is simply a way of incorporating
+        the probability of the termination symbol on the last frame.
     """
     # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, s_range, C) = logits.shape
-    assert ranges.shape == (B, T, s_range)
+    assert ranges.shape == (B, T, s_range), ranges.shape
     (B, S) = symbols.shape
-    assert S >= 1
-    assert T >= S
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
     normalizers = torch.logsumexp(logits, dim=3)
@@ -813,7 +900,7 @@ def get_rnnt_logprobs_pruned(
     px = px.permute((0, 2, 1))
-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,
@@ -846,8 +933,10 @@ def get_rnnt_logprobs_pruned(
     # (B, S + 1, T)
     py = py.permute((0, 2, 1))
-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]
     return (px, py)
@@ -858,13 +947,13 @@ def rnnt_loss_pruned(
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
 ) -> Tensor:
-    """A RNN-T loss with pruning, which uses a pruned 'joiner' network output
-    as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
-    s_range means the symbols number kept for each frame.
+    """A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
+    network as input, i.e. a 4-dimensional tensor of shape (B, T, s_range, C),
+    where s_range is the number of symbols kept for each frame.
     Args:
       logits:
@@ -882,8 +971,19 @@ def rnnt_loss_pruned(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
       delay_penalty: A constant value to penalize symbol delay, this may be
           needed when training with time masking, to avoid the time-masking
           encouraging the network to delay symbols.
@@ -895,8 +995,8 @@ def rnnt_loss_pruned(
         `sum`: the output will be summed.
         Default: `mean`
     Returns:
-      If recursion is `none`, returns a tensor of shape (B,), containing the
-      total RNN-T loss values for each element of the batch, otherwise a scalar
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each sequence of the batch, otherwise a scalar
       with the reduction applied.
     """
     px, py = get_rnnt_logprobs_pruned(
@@ -905,11 +1005,12 @@ def rnnt_loss_pruned(
         ranges=ranges,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
     if delay_penalty > 0.0:
         B, S, T0 = px.shape
-        T = T0 if modified else T0 - 1
+        T = T0 if rnnt_type != "regular" else T0 - 1
         if boundary is None:
             offset = torch.tensor(
                 (T - 1) / 2, dtype=px.dtype, device=px.device,
@@ -921,6 +1022,7 @@ def rnnt_loss_pruned(
         ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
     negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
     if reduction == "none":
         return -negated_loss
@@ -929,9 +1031,9 @@ def rnnt_loss_pruned(
     elif reduction == "sum":
         return -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
 def get_rnnt_logprobs_smoothed(
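Taken together, the pruned workflow is: simple loss → pruning bounds → pruned joiner → pruned loss. A sketch mirroring the calls in this commit's tests (the sizes and the trivial `am + lm` joiner are illustrative):

```python
import torch
import fast_rnnt

B, S, T, C = 2, 10, 50, 100  # illustrative sizes
am = torch.randn(B, T, C)            # encoder output
lm = torch.randn(B, S + 1, C)        # prediction-network output
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S  # end_symbol
boundary[:, 3] = T  # end_frame

# 1. The simple loss also returns the (px, py) gradients used for pruning.
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
    rnnt_type="regular",
    return_grad=True,
    reduction="none",
)

# 2. Decide which s_range symbols to keep on each frame.
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5
)

# 3. Prune am/lm, run the (here trivial) joiner, compute the pruned loss.
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = pruned_am + pruned_lm  # a real joiner network would go here
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=0,
    boundary=boundary,
    rnnt_type="regular",
    reduction="none",
)
```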
@@ -942,7 +1044,7 @@ def get_rnnt_logprobs_smoothed(
     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """
     Reduces RNN-T problem (the simple case, where joiner network is just
@@ -1005,18 +1107,32 @@ def get_rnnt_logprobs_smoothed(
         Most likely you will want begin_symbol and begin_frame to be zero.
       modified: if True, each time a real symbol is consumed a frame will
           also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
     Returns:
         (px, py) (the names are quite arbitrary).
-           px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
+           px: logprobs, of shape [B][S][T+1] if rnnt_type == "regular",
+               [B][S][T] if rnnt_type != "regular".
           py: logprobs, of shape [B][S+1][T]
         in the recursion::
           p[b,0,0] = 0.0
-          if !modified:
+          if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-          if modified:
+          if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])
     .. where p[b][s][t] is the "joint score" of the pair of subsequences
@@ -1031,15 +1147,16 @@ def get_rnnt_logprobs_smoothed(
     we cannot emit any symbols. This is simply a way of incorporating
     the probability of the termination symbol on the last frame.
     """
-    assert lm.ndim == 3
-    assert am.ndim == 3
-    assert lm.shape[0] == am.shape[0]
-    assert lm.shape[2] == am.shape[2]
+    assert lm.ndim == 3, lm.ndim
+    assert am.ndim == 3, am.ndim
+    assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
+    assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
     # Caution: some parts of this code are a little less clear than they could
     # be due to optimizations. In particular it may not be totally obvious that
@@ -1091,7 +1208,7 @@ def get_rnnt_logprobs_smoothed(
         -1
     )  # [B][S][T]
-    if not modified:
+    if rnnt_type == "regular":
         px_am = torch.cat(
             (
                 px_am,
@@ -1150,8 +1267,10 @@ def get_rnnt_logprobs_smoothed(
         + py_amonly * am_only_scale
     )
-    if not modified:
+    if rnnt_type == "regular":
         px_interp = fix_for_boundary(px_interp, boundary)
+    elif rnnt_type == "constrained":
+        px_interp += py_interp[:, 1:, :]
     return (px_interp, py_interp)
@@ -1164,7 +1283,7 @@ def rnnt_loss_smoothed(
     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
     return_grad: bool = False,
@@ -1197,8 +1316,19 @@ def rnnt_loss_smoothed(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                   frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                   frame when you emit a non-blank symbol, but this is done
+                   by "forcing" you to take the blank transition from the
+                   *next* context on the *current* frame, e.g. if we emit
+                   c given "a b" context, we are forced to emit "blank"
+                   given "b c" context on the current frame.
       delay_penalty: A constant value to penalize symbol delay, this may be
           needed when training with time masking, to avoid the time-masking
           encouraging the network to delay symbols.
@@ -1233,11 +1363,12 @@ def rnnt_loss_smoothed(
         lm_only_scale=lm_only_scale,
         am_only_scale=am_only_scale,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
     if delay_penalty > 0.0:
         B, S, T0 = px.shape
-        T = T0 if modified else T0 - 1
+        T = T0 if rnnt_type != "regular" else T0 - 1
         if boundary is None:
             offset = torch.tensor(
                 (T - 1) / 2, dtype=px.dtype, device=px.device,
@@ -1249,6 +1380,7 @@ def rnnt_loss_smoothed(
         ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
     scores_and_grads = mutual_information_recursion(
         px=px, py=py, boundary=boundary, return_grad=return_grad
     )
@@ -1260,7 +1392,7 @@ def rnnt_loss_smoothed(
     elif reduction == "sum":
         loss = -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
     return (loss, scores_and_grads[1]) if return_grad else loss
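A usage sketch for the smoothed loss, following the calls in this commit's tests (sizes are illustrative). With `lm_only_scale` and `am_only_scale` set to 0.0 the tests expect it to match the simple loss; nonzero values interpolate in LM-only and AM-only terms:

```python
import torch
import fast_rnnt

B, S, T, C = 2, 5, 20, 30  # illustrative sizes
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

loss = fast_rnnt.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    lm_only_scale=0.1,  # weight of the LM-only (acoustically blind) term
    am_only_scale=0.1,  # weight of the AM-only (label-blind) term
    boundary=boundary,
    rnnt_type="modified",
)
```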
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
         assert px.shape == (B, S, T + 1)
         assert py.shape == (B, S + 1, T)
         assert symbols.shape == (B, S)
-        m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
+        m = fast_rnnt.mutual_information_recursion(
+            px=px, py=py, boundary=None
+        )
         if device == torch.device("cpu"):
             expected = -m
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
         boundary_[:, 2] = seq_length
         boundary_[:, 3] = frames
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # lm: [B][S+1][C]
                 lm = lm_.to(device)
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
-                assert px.shape == (B, S, T) if modified else (B, S, T + 1)
+                assert px.shape == (
+                    (B, S, T)
+                    if rnnt_type != "regular"
+                    else (B, S, T + 1)
+                )
                 assert py.shape == (B, S + 1, T)
                 assert symbols.shape == (B, S)
                 m = fast_rnnt.mutual_information_recursion(
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
                     lm_only_scale=0.0,
                     am_only_scale=0.0,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
                 # compare with torchaudio rnnt_loss
-                if self.has_torch_rnnt_loss and not modified:
+                if self.has_torch_rnnt_loss and rnnt_type == "regular":
                     import torchaudio.functional
                     m = torchaudio.functional.rnnt_loss(
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
                     lm_only_scale=0.0,
                     am_only_scale=0.0,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
                 torch_grad = torch.autograd.grad(torch_loss, logits2)
                 torch_grad = torch_grad[0]
-                assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
+                assert torch.allclose(
+                    fast_loss, torch_loss, atol=1e-2, rtol=1e-2
+                )
-                assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)
+                assert torch.allclose(
+                    fast_grad, torch_grad, atol=1e-2, rtol=1e-2
+                )
     def test_rnnt_loss_smoothed(self):
         B = 1
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
         boundary_[:, 2] = seq_length
         boundary_[:, 3] = frames
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # normal rnnt
                 am = am_.to(device)
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
-                print(
-                    f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
-                )
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")
                 # pruning
                 simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     return_grad=True,
                     reduction="none",
                 )
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
                         s_range=r,
                     )
                     # (B, T, r, C)
-                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
                     logits = pruned_am + pruned_lm
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
                         ranges=ranges,
                         termination_symbol=terminal_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                         reduction="none",
                     )
                     print(f"Pruning loss with range {r} : {pruned_loss}")
         # Test the sequences that only have a small number of symbols;
         # in this circumstance, s_range would be greater than S, which would
         # raise errors (like a nan or inf loss) in our previous versions.
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
         print(f"B = {B}, T = {T}, S = {S}, C = {C}")
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # normal rnnt
                 am = am_.to(device)
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     reduction="none",
                 )
-                print(
-                    f"Unpruned rnnt loss with modified {modified} : {loss}"
-                )
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
                 # pruning
                 simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     return_grad=True,
                     reduction="none",
                 )
                 S0 = 2
-                if modified:
+                if rnnt_type == "regular":
                     S0 = 1
                 for r in range(S0, S + 2):
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
                         ranges=ranges,
                         termination_symbol=terminal_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                         reduction="none",
                     )
                     print(f"Pruned loss with range {r} : {pruned_loss}")
 if __name__ == "__main__":
     unittest.main()