Commit b0ed23ef authored by pkufool's avatar pkufool

Add constrained rnnt

parent dc35168d
......@@ -26,13 +26,13 @@ from .mutual_information import mutual_information_recursion
def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and modified == False, px[:,:,-1] will
None. If boundary == None and rnnt_type == "regular", px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
Args:
px: a Tensor of shape [B][S][T+1] (this function is only
called if modified == False, see other docs for `modified`)
called if rnnt_type == "regular", see other docs for `rnnt_type`)
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][3] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
......@@ -49,8 +49,8 @@ def get_rnnt_logprobs(
am: Tensor,
symbols: Tensor,
termination_symbol: int,
rnnt_type: str = "regular",
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
......@@ -97,20 +97,32 @@ def get_rnnt_logprobs(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
......@@ -121,21 +133,22 @@ def get_rnnt_logprobs(
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
assert lm.ndim == 3, lm.ndim
assert am.ndim == 3, am.ndim
assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
assert S >= 1
assert T >= S
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
......@@ -162,7 +175,7 @@ def get_rnnt_logprobs(
-1
) # [B][S][T]
if not modified:
if rnnt_type == "regular":
px_am = torch.cat(
(
px_am,
......@@ -189,8 +202,10 @@ def get_rnnt_logprobs(
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
if not modified:
if rnnt_type == "regular":
px = fix_for_boundary(px, boundary)
elif rnnt_type == "constrained":
px += py[:, 1:, :]
return (px, py)
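A minimal usage sketch (not part of this diff): assuming the package is importable as `fast_rnnt` and exports `get_rnnt_logprobs` the way the tests below do, this shows how the shape of `px` depends on `rnnt_type`:

import torch
import fast_rnnt

B, S, T, C = 2, 3, 10, 8
lm = torch.randn(B, S + 1, C)          # prediction-network (lm) output
am = torch.randn(B, T, C)              # encoder (am) output
symbols = torch.randint(1, C, (B, S))  # 0 is used as the termination symbol

for rnnt_type in ["regular", "modified", "constrained"]:
    px, py = fast_rnnt.get_rnnt_logprobs(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,
        rnnt_type=rnnt_type,
    )
    # regular rnnt keeps an extra "one-past-the-last" frame in px
    assert px.shape == ((B, S, T + 1) if rnnt_type == "regular" else (B, S, T))
    assert py.shape == (B, S + 1, T)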
......@@ -201,7 +216,7 @@ def rnnt_loss_simple(
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
......@@ -227,8 +242,19 @@ def rnnt_loss_simple(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay; this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
......@@ -260,12 +286,12 @@ def rnnt_loss_simple(
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
......@@ -289,9 +315,9 @@ def rnnt_loss_simple(
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
return (loss, scores_and_grads[1]) if return_grad else loss
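To make the `delay_penalty` term concrete: with `boundary` None, every frame t gets a bonus proportional to `offset - t` added to `px`, so emitting a symbol before the midpoint of the utterance scores higher and emitting it late scores lower. A sketch of just the penalty tensor (the exact broadcasting code is elided in this hunk, so treat the details as assumptions):

import torch

T0, delay_penalty = 11, 0.1
T = T0 - 1                # for rnnt_type == "regular"; T = T0 otherwise
offset = (T - 1) / 2      # midpoint of the valid frames
penalty = (offset - torch.arange(T0, dtype=torch.float32)).reshape(1, 1, T0)
penalty = penalty * delay_penalty
# penalty broadcasts against px ([B][S][T0]); it is positive for early
# frames and negative for late ones, penalizing delayed emissions.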
......@@ -300,7 +326,7 @@ def get_rnnt_logprobs_joint(
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
"""Reduces RNN-T problem to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
......@@ -321,21 +347,33 @@ def get_rnnt_logprobs_joint(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
Returns:
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1]
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
......@@ -345,17 +383,18 @@ def get_rnnt_logprobs_joint(
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert logits.ndim == 4
assert logits.ndim == 4, logits.ndim
(B, T, S1, C) = logits.shape
S = S1 - 1
assert symbols.shape == (B, S)
assert S >= 1
assert T >= S
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
normalizers = torch.logsumexp(logits, dim=3)
normalizers = normalizers.permute((0, 2, 1))
......@@ -365,7 +404,7 @@ def get_rnnt_logprobs_joint(
).squeeze(-1)
px = px.permute((0, 2, 1))
if not modified:
if rnnt_type == "regular":
px = torch.cat(
(
px,
......@@ -383,8 +422,10 @@ def get_rnnt_logprobs_joint(
) # [B][S+1][T]
py -= normalizers
if not modified:
if rnnt_type == "regular":
px = fix_for_boundary(px, boundary)
elif rnnt_type == "constrained":
px += py[:, 1:, :]
return (px, py)
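The `constrained` branch above can be read in isolation: `px += py[:, 1:, :]` folds the blank logprob of the *next* context into each non-blank transition on the same frame, which is exactly the "forced blank" described in the docstring. A tiny self-contained illustration:

import torch

B, S, T = 1, 3, 4
px = torch.randn(B, S, T)       # non-blank logprobs, [B][S][T]
py = torch.randn(B, S + 1, T)   # blank logprobs, [B][S+1][T]
# py[:, 1:, :] is the blank logprob given the context *after* emitting
# the corresponding symbol; adding it makes that blank mandatory.
px_constrained = px + py[:, 1:, :]
assert px_constrained.shape == (B, S, T)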
......@@ -394,7 +435,7 @@ def rnnt_loss(
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
......@@ -415,8 +456,19 @@ def rnnt_loss(
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay; this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
......@@ -438,11 +490,12 @@ def rnnt_loss(
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
......@@ -454,6 +507,7 @@ def rnnt_loss(
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
......@@ -462,30 +516,30 @@ def rnnt_loss(
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
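An end-to-end call of `rnnt_loss` might look like the following sketch (it assumes the package is importable as `fast_rnnt`, as in the tests below, and that the joiner output is already materialized for every (frame, context) pair):

import torch
import fast_rnnt

B, T, S, C = 2, 10, 3, 8
logits = torch.randn(B, T, S + 1, C, requires_grad=True)
symbols = torch.randint(1, C, (B, S))

loss = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=0,
    rnnt_type="constrained",
    reduction="mean",
)
loss.backward()   # gradients flow back into logits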
def _adjust_pruning_lower_bound(
s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
"""Adjust s_begin (pruning lower bound) to make it satisfied the following
constrains
"""Adjust s_begin (pruning lower bounds) to make it satisfy the following
constraints
- monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
- start with symbol 0 at first frame.
- s_begin[i + 1] - s_begin[i] < s_range, whicn means that we can't skip
- s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
any symbols.
To make it monotonic increasing, we can use `monotonic_lower_bound` function
in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
traverse the array in reverse order and update the elements by
`min_value = min(a_begin[i], min_value)`, the initial `min_value` set to
`min_value = min(a_begin[i], min_value)`; the initial `min_value` is set to
`inf`.
The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
constrain is a little tricky. We first transform `s_begin` with
constraint is a little tricky. We first transform `s_begin` with
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
then we make the transformed `s_begin` monotonic increasing, after that,
we transform back `s_begin` with the same formula as the previous
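A toy rendering of the trick just described (the in-tree code uses k2's `monotonic_lower_bound` kernel; the Python loop below only makes the reverse running minimum explicit):

import torch

s_range, T = 2, 4
s_begin = torch.tensor([0, 1, 2, 4])  # the jump of 2 at the end skips a symbol
b = -(s_begin - (s_range - 1) * torch.arange(T))  # transform from the docstring
for i in range(T - 2, -1, -1):                    # reverse running minimum,
    b[i] = torch.minimum(b[i], b[i + 1])          # making b non-decreasing
s_begin = -b + (s_range - 1) * torch.arange(T)    # transform back
# s_begin is now tensor([1, 2, 3, 4]): monotonically increasing, with every
# consecutive gap < s_range, so no symbol can be skipped.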
......@@ -551,9 +605,9 @@ def get_rnnt_prune_ranges(
Note:
For the generated tensor ranges (assuming batch size is 1), ranges[:, 0]
is a monotonic increasing tensor from 0 to `len(symbols)` and it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
symbols.
is a monotonically increasing tensor from 0 to `len(symbols) - s_range` and
it satisfies `ranges[t+1, 0] - ranges[t, 0] < s_range`, which means we
won't skip any symbols.
Args:
px_grad:
......@@ -568,21 +622,21 @@ def get_rnnt_prune_ranges(
s_range:
How many symbols to keep for each frame.
Returns:
A tensor contains the kept symbols indexes for each frame, with shape
(B, T, s_range).
A tensor of shape (B, T, s_range) containing the indexes of the
kept symbols for each frame.
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]
assert T1 in [T, T + 1]
assert T1 in [T, T + 1], T1
S1 = S + 1
assert py_grad.shape == (B, S + 1, T)
assert boundary.shape == (B, 4)
assert py_grad.shape == (B, S + 1, T), py_grad.shape
assert boundary.shape == (B, 4), boundary.shape
assert S >= 1
assert T >= S
assert S >= 1, S
assert T >= S, (T, S)
# s_range > S means we won't prune out any symbols. To make indexing with
# ranges runs normally, s_range should be equal to or less than ``S + 1``.
# ranges run normally, s_range should be equal to or less than ``S + 1``.
if s_range > S:
s_range = S + 1
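For intuition, a hypothetical `ranges` for one sequence with T = 4 frames, S = 5 symbols and s_range = 2 (numbers invented for illustration); each row lists the symbol positions kept for that frame:

import torch

ranges = torch.tensor([[[0, 1],
                        [1, 2],
                        [2, 3],
                        [3, 4]]])  # (B, T, s_range) = (1, 4, 2)
# ranges[0, :, 0] increases monotonically and never jumps by s_range or
# more, so every symbol index 0..4 is covered by at least one frame.
assert (ranges[0, 1:, 0] - ranges[0, :-1, 0]).max() < 2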
......@@ -630,16 +684,17 @@ def get_rnnt_prune_ranges(
mask = mask < boundary[:, 3].reshape(B, 1) - 1
s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
# handle the cases when `len(symbols) < s_range`
# handle the cases where `len(symbols) < s_range`
s_begin_padding = torch.clamp(s_begin_padding, min=0)
s_begin = torch.where(mask, s_begin, s_begin_padding)
# adjusting lower bound to make it satisfied some constrains, see docs in
# `_adjust_pruning_lower_bound` for more details of these constrains.
# T1 == T here means we are using the modified version of transducer,
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# it only emits one symbol per frame.
# Adjust the lower bounds so that they satisfy the constraints; see the
# docs in `_adjust_pruning_lower_bound` for details. T1 == T here means
# we are using a non-regular (i.e. modified or constrained) version of
# the transducer, for which the third constraint becomes
# `s_begin[i + 1] - s_begin[i] < 2`, because at most one symbol is
# emitted per frame.
s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
......@@ -652,8 +707,8 @@ def get_rnnt_prune_ranges(
def do_rnnt_pruning(
am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prune the output of encoder(am) output and prediction network(lm)
output of RNNT.
"""Prune the output of encoder(am) and prediction network(lm) with ranges
generated by `get_rnnt_prune_ranges`.
Args:
am:
......@@ -671,9 +726,9 @@ def do_rnnt_pruning(
# am (B, T, C)
# lm (B, S + 1, C)
# ranges (B, T, s_range)
assert ranges.shape[0] == am.shape[0]
assert ranges.shape[0] == lm.shape[0]
assert am.shape[1] == ranges.shape[1]
assert ranges.shape[0] == am.shape[0], (ranges.shape[0], am.shape[0])
assert ranges.shape[0] == lm.shape[0], (ranges.shape[0], lm.shape[0])
assert am.shape[1] == ranges.shape[1], (am.shape[1], ranges.shape[1])
(B, T, s_range) = ranges.shape
(B, S1, C) = lm.shape
S = S1 - 1
......@@ -711,9 +766,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
[ 8, 9, 5, 6, 7],
[12, 13, 14, 10, 11]]])
"""
assert src.dim() == 3
assert src.dim() == 3, src.dim()
(B, T, S) = src.shape
assert shifts.shape == (B, T)
assert shifts.shape == (B, T), shifts.shape
index = (
torch.arange(S, device=src.device)
......@@ -731,7 +786,7 @@ def get_rnnt_logprobs_pruned(
ranges: Tensor,
termination_symbol: int,
boundary: Tensor,
modified: bool = False,
rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
......@@ -751,21 +806,53 @@ def get_rnnt_logprobs_pruned(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
Returns:
Return the px (B, S, T) if modified else (B, S, T + 1) and
py (B, S + 1, T) needed by mutual_information_recursion.
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
# logits (B, T, s_range, C)
# symbols (B, S)
# ranges (B, T, s_range)
assert logits.ndim == 4
assert logits.ndim == 4, logits.ndim
(B, T, s_range, C) = logits.shape
assert ranges.shape == (B, T, s_range)
assert ranges.shape == (B, T, s_range), ranges.shape
(B, S) = symbols.shape
assert S >= 1
assert T >= S
assert S >= 1, S
assert T >= S, (T, S)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
normalizers = torch.logsumexp(logits, dim=3)
......@@ -813,7 +900,7 @@ def get_rnnt_logprobs_pruned(
px = px.permute((0, 2, 1))
if not modified:
if rnnt_type == "regular":
px = torch.cat(
(
px,
......@@ -846,8 +933,10 @@ def get_rnnt_logprobs_pruned(
# (B, S + 1, T)
py = py.permute((0, 2, 1))
if not modified:
if rnnt_type == "regular":
px = fix_for_boundary(px, boundary)
elif rnnt_type == "constrained":
px += py[:, 1:, :]
return (px, py)
......@@ -858,13 +947,13 @@ def rnnt_loss_pruned(
ranges: Tensor,
termination_symbol: int,
boundary: Tensor = None,
modified: bool = False,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses a pruned 'joiner' network output
as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
s_range means the symbols number kept for each frame.
"""A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
s_range means the number of symbols kept for each frame.
Args:
logits:
......@@ -882,8 +971,19 @@ def rnnt_loss_pruned(
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay; this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
......@@ -895,8 +995,8 @@ def rnnt_loss_pruned(
`sum`: the output will be summed.
Default: `mean`
Returns:
If recursion is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each sequence of the batch, otherwise a scalar
with the reduction applied.
"""
px, py = get_rnnt_logprobs_pruned(
......@@ -905,11 +1005,12 @@ def rnnt_loss_pruned(
ranges=ranges,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
......@@ -921,6 +1022,7 @@ def rnnt_loss_pruned(
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
......@@ -929,9 +1031,9 @@ def rnnt_loss_pruned(
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
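Putting the pruned pieces together, a sketch that mirrors the tests later in this commit (the `fast_rnnt` import and the trivial add-joiner are assumptions; a real model would apply a nonlinear joiner to the pruned sums):

import torch
import fast_rnnt

B, T, S, C = 2, 10, 3, 8
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# 1) simple loss, keeping the (px, py) gradients that guide pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0,
    boundary=boundary, rnnt_type="constrained",
    return_grad=True, reduction="none",
)
# 2) derive per-frame symbol ranges and prune am/lm before the joiner
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=2,
)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = pruned_am + pruned_lm  # (B, T, s_range, C)
# 3) pruned loss on the pruned joiner output
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=0, boundary=boundary,
    rnnt_type="constrained", reduction="none",
)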
def get_rnnt_logprobs_smoothed(
......@@ -942,7 +1044,7 @@ def get_rnnt_logprobs_smoothed(
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
......@@ -1005,18 +1107,32 @@ def get_rnnt_logprobs_smoothed(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
px: logprobs, of shape [B][S][T+1] if rnnt_type == "regular",
[B][S][T] if rnnt_type != "regular".
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
......@@ -1031,15 +1147,16 @@ def get_rnnt_logprobs_smoothed(
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
assert lm.ndim == 3, lm.ndim
assert am.ndim == 3, am.ndim
assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
assert S >= 1
assert T >= S
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
......@@ -1091,7 +1208,7 @@ def get_rnnt_logprobs_smoothed(
-1
) # [B][S][T]
if not modified:
if rnnt_type == "regular":
px_am = torch.cat(
(
px_am,
......@@ -1150,8 +1267,10 @@ def get_rnnt_logprobs_smoothed(
+ py_amonly * am_only_scale
)
if not modified:
if rnnt_type == "regular":
px_interp = fix_for_boundary(px_interp, boundary)
elif rnnt_type == "constrained":
px_interp += py_interp[:, 1:, :]
return (px_interp, py_interp)
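The interpolation above, shown in isolation for `py` (a sketch; `py_lmonly` and `py_amonly` stand for the "lm only" and "am only" variants computed earlier in this function):

import torch

lm_only_scale, am_only_scale = 0.1, 0.1
py, py_lmonly, py_amonly = (torch.randn(2, 4, 10) for _ in range(3))
py_interp = (
    py * (1.0 - lm_only_scale - am_only_scale)
    + py_lmonly * lm_only_scale
    + py_amonly * am_only_scale
)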
......@@ -1164,7 +1283,7 @@ def rnnt_loss_smoothed(
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
......@@ -1197,8 +1316,19 @@ def rnnt_loss_smoothed(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, where only emitting a blank takes you to
the next frame (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt where you go to the next frame
whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that also goes to the
next frame when you emit a non-blank symbol, but this is
done by "forcing" you to take the blank transition from
the *next* context on the *current* frame, e.g. if we emit
c given the "a b" context, we are forced to emit "blank"
given the "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay; this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
......@@ -1233,11 +1363,12 @@ def rnnt_loss_smoothed(
lm_only_scale=lm_only_scale,
am_only_scale=am_only_scale,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if modified else T0 - 1
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
......@@ -1249,6 +1380,7 @@ def rnnt_loss_smoothed(
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
......@@ -1260,7 +1392,7 @@ def rnnt_loss_smoothed(
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
return (loss, scores_and_grads[1]) if return_grad else loss
......@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=None
)
if device == torch.device("cpu"):
expected = -m
......@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
......@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert px.shape == (
(B, S, T) if rnnt_type != "regular" else (B, S, T + 1)
), px.shape
assert px.shape == (B, S, T) if modified else (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
......@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified:
if self.has_torch_rnnt_loss and rnnt_type == "regular":
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
......@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
......@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad = torch.autograd.grad(torch_loss, logits2)
torch_grad = torch_grad[0]
assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_loss, torch_loss, atol=1e-2, rtol=1e-2
)
assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_grad, torch_grad, atol=1e-2, rtol=1e-2
)
def test_rnnt_loss_smoothed(self):
B = 1
......@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
......@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
print(
f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
)
print(f"Unpruned rnnt loss with {rnnt_loss} rnnt : {fast_loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
......@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
......@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)
logits = pruned_am + pruned_lm
......@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruning loss with range {r} : {pruned_loss}")
# Test sequences that have only a small number of symbols; in this
# circumstance s_range would be greater than S, which raised errors
# (like nan or inf loss) in our previous versions.
......@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print(f"B = {B}, T = {T}, S = {S}, C = {C}")
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
......@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(
f"Unpruned rnnt loss with modified {modified} : {loss}"
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
......@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
S0 = 2
if modified:
if rnnt_type == "regular":
S0 = 1
for r in range(S0, S + 2):
......@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruned loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()