"examples/vscode:/vscode.git/clone" did not exist on "9e9a9488f1fb421c92f89da518f30518b57e057e"
Unverified commit 83e4637f authored by durson, committed by GitHub

added return_grad for all types of rnnt loss (#29)

* added return_grad for all types of rnnt loss

* lifted the T >= S requirement for the regular case

* black reformat

* black -l80 reformat

* fixed s_range adjustment rule
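Taken together, passing return_grad=True to any of the losses now returns a tuple (loss, (px_grad, py_grad)) instead of just the loss, where px_grad/py_grad are the occupation gradients consumed by get_rnnt_prune_ranges. A minimal usage sketch (the setup values are illustrative, not from this commit; the call signature follows the diff below):

import torch
import fast_rnnt

B, T, S, C = 2, 8, 4, 10  # batch, frames, symbols, vocab (illustrative)
logits = torch.randn(B, T, S + 1, C)
symbols = torch.randint(0, C - 1, (B, S))

loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=C - 1,
    rnnt_type="regular",
    reduction="sum",
    return_grad=True,
)
# px_grad / py_grad can be fed to fast_rnnt.get_rnnt_prune_ranges.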
parent 6a4b834f
@@ -22,6 +22,26 @@ from typing import Optional, Tuple, Union
 from .mutual_information import mutual_information_recursion


+def validate_st_lengths(
+    S: int,
+    T: int,
+    is_rnnt_type_regular: bool,
+    boundary: Optional[Tensor] = None,
+):
+    if boundary is None:
+        assert S >= 1, S
+        assert (
+            is_rnnt_type_regular or T >= S
+        ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
+    else:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        assert (Ss >= 1).all(), Ss
+        assert (
+            is_rnnt_type_regular or (Ts >= Ss).all()
+        ), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"
+
+
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
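Note (an inference from how boundary is indexed above, not stated in this hunk): each boundary row is interpreted as [begin_symbol, begin_frame, end_symbol, end_frame], so columns 2 and 3 carry the per-sequence S and T that the new helper checks. A small sketch of its behavior:

import torch

# Regular RNN-T: S >= 1 suffices; T may now be smaller than S.
validate_st_lengths(S=10, T=2, is_rnnt_type_regular=True)

# Modified/constrained RNN-T still requires T >= S, so this would raise:
# validate_st_lengths(S=10, T=2, is_rnnt_type_regular=False)  # AssertionError

# With a boundary, the same checks run per sequence on columns 2 and 3.
boundary = torch.tensor([[0, 0, 3, 5], [0, 0, 4, 4]])
validate_st_lengths(S=4, T=5, is_rnnt_type_regular=False, boundary=boundary)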
@@ -145,8 +165,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     # subtracting am_max and lm_max is to ensure the probs are in a good range
@@ -391,8 +411,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)
@@ -437,6 +457,7 @@ def rnnt_loss(
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
+    return_grad: bool = False,
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
     i.e. a 4 dimensions tensor.
@@ -509,20 +530,24 @@ def rnnt_loss(
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

-    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
     if reduction == "none":
-        return -negated_loss
+        loss = -negated_loss
     elif reduction == "mean":
-        return -torch.mean(negated_loss)
+        loss = -torch.mean(negated_loss)
     elif reduction == "sum":
-        return -torch.sum(negated_loss)
+        loss = -torch.sum(negated_loss)
     else:
         raise ValueError(
             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
         )
+    return (loss, scores_and_grads[1]) if return_grad else loss


-def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+def _monotonic_lower_bound(x: Tensor) -> Tensor:
     """Compute a monotonically increasing lower bound of the tensor `x` on the
     last dimension. The basic idea is: we traverse the tensor in reverse order,
     and update current element with the following statement,
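The update statement itself falls outside this hunk. For orientation, a naive reference of the assumed semantics (clamp each element by its successor, traversing the last dimension in reverse, so the result is a non-decreasing elementwise lower bound) might look like:

import torch

def monotonic_lower_bound_ref(x: torch.Tensor) -> torch.Tensor:
    # Returns y with y <= x elementwise and y non-decreasing on the last dim.
    y = x.clone()
    for i in range(y.size(-1) - 2, -1, -1):
        y[..., i] = torch.minimum(y[..., i], y[..., i + 1])
    return y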
@@ -556,9 +581,7 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
     return x


-def _adjust_pruning_lower_bound(
-    s_begin: torch.Tensor, s_range: int
-) -> torch.Tensor:
+def _adjust_pruning_lower_bound(s_begin: Tensor, s_range: int) -> Tensor:
     """Adjust s_begin (pruning lower bounds) to make it satisfy the following
     constraints
@@ -613,11 +636,11 @@ def _adjust_pruning_lower_bound(
 # chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
 # (https://arxiv.org/pdf/2206.13236.pdf)
 def get_rnnt_prune_ranges(
-    px_grad: torch.Tensor,
-    py_grad: torch.Tensor,
-    boundary: torch.Tensor,
+    px_grad: Tensor,
+    py_grad: Tensor,
+    boundary: Tensor,
     s_range: int,
-) -> torch.Tensor:
+) -> Tensor:
     """Get the pruning ranges of normal rnnt loss according to the grads
     of px and py returned by mutual_information_recursion.
@@ -661,28 +684,44 @@ def get_rnnt_prune_ranges(
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
+    is_regular = T1 != T
     assert T1 in [T, T + 1], T1
     S1 = S + 1
     assert py_grad.shape == (B, S + 1, T), py_grad.shape
     assert boundary.shape == (B, 4), boundary.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, is_regular, boundary)
+
+    # in regular case s_range should be no less than
+    # a minimum integer satisfying `(s_range - 1) * t + 1 >= s + 1`
+    if is_regular:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        s_range_min = (
+            Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
+        )
+        if s_range < s_range_min:
+            print(
+                f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+                f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
+            )
+            s_range = s_range_min

     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1

-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
+    else:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."

     (B_stride, S_stride, T_stride) = py_grad.stride()
     blk_grad = torch.as_strided(
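A quick numeric check of the new adjustment rule (values illustrative): the bound comes from requiring (s_range - 1) * T + 1 >= S + 1, i.e. pruning windows that advance by at most s_range - 1 symbols per frame must still be able to reach the last symbol. For S=10, T=2:

import torch

Ss, Ts = torch.tensor([10]), torch.tensor([2])
s_range_min = Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
assert s_range_min == 6                       # floor((10 - 1) / 2) + 2
assert (s_range_min - 1) * 2 + 1 >= 10 + 1    # 11 >= 11: rule satisfied
assert not ((5 - 1) * 2 + 1 >= 10 + 1)        # s_range=5 would be too small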
@@ -739,8 +778,8 @@ def get_rnnt_prune_ranges(
 def do_rnnt_pruning(
-    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    am: Tensor, lm: Tensor, ranges: Tensor
+) -> Tuple[Tensor, Tensor]:
     """Prune the output of encoder(am) and prediction network(lm) with ranges
     generated by `get_rnnt_prune_ranges`.
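For shape orientation (a hedged sketch of the expected shapes; E names the joiner/embedding dimension and does not appear in this diff): am is (B, T, E), lm is (B, S + 1, E), ranges is (B, T, s_range) int64, and both outputs come back as (B, T, s_range, E):

import torch
import fast_rnnt

B, T, S, E, s_range = 2, 8, 4, 16, 3
am = torch.randn(B, T, E)
lm = torch.randn(B, S + 1, E)
ranges = torch.zeros(B, T, s_range, dtype=torch.int64)  # trivial valid ranges
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
assert pruned_am.shape == pruned_lm.shape == (B, T, s_range, E)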
@@ -779,7 +818,7 @@ def do_rnnt_pruning(
     return am_pruning, lm_pruning


-def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
+def _roll_by_shifts(src: Tensor, shifts: torch.LongTensor):
     """Roll tensor with different shifts for each row.

     Note:
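The Note in this docstring is truncated by the hunk boundary. As a reference for the assumed behavior, a hypothetical per-row circular roll for 2-D input (not the library's implementation) could be written as:

import torch

def roll_rows_ref(src: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    # out[n, j] = src[n, (j - shifts[n]) % L], i.e. torch.roll applied
    # row by row with a different shift per row.
    N, L = src.shape
    idx = (torch.arange(L, device=src.device) - shifts.unsqueeze(1)) % L
    return torch.gather(src, 1, idx)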
@@ -819,7 +858,7 @@ def get_rnnt_logprobs_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
@@ -888,10 +927,14 @@ def get_rnnt_logprobs_pruned(
     # ranges (B, T, s_range)
     assert logits.ndim == 4, logits.ndim
     (B, T, s_range, C) = logits.shape
-    assert ranges.shape == (B, T, s_range), ranges.shape
+    assert ranges.shape == (
+        B,
+        T,
+        s_range,
+    ), f"{ranges.shape} == ({B}, {T}, {s_range})"
     (B, S) = symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)
@@ -986,10 +1029,11 @@ def rnnt_loss_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor = None,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
+    return_grad: bool = False,
 ) -> Tensor:
     """A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
     network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1071,17 +1115,21 @@ def rnnt_loss_pruned(
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

-    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
     if reduction == "none":
-        return -negated_loss
+        loss = -negated_loss
     elif reduction == "mean":
-        return -torch.mean(negated_loss)
+        loss = -torch.mean(negated_loss)
     elif reduction == "sum":
-        return -torch.sum(negated_loss)
+        loss = -torch.sum(negated_loss)
     else:
         raise ValueError(
             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
         )
+    return (loss, scores_and_grads[1]) if return_grad else loss


 def get_rnnt_logprobs_smoothed(
@@ -1202,8 +1250,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     # Caution: some parts of this code are a little less clear than they could
...
@@ -206,7 +206,6 @@ class TestMutualInformation(unittest.TestCase):
         for dtype in self.dtypes:
             for device in self.devices:
                 if random_boundary:
-
                     def get_boundary_row():
...
@@ -343,7 +343,6 @@ class TestRnntLoss(unittest.TestCase):
         boundary_[:, 3] = frames

         for device in self.devices:
-
             # lm: [B][S+1][C]
             lm = lm_.to(device)
             # am: [B][T][C]
@@ -609,6 +608,102 @@
                 )
                 print(f"Pruned loss with range {r} : {pruned_loss}")

+    # Test low s_range values with large S and small T;
+    # in this circumstance the s_range would not be enough
+    # to cover the whole sequence length (in regular rnnt mode)
+    # and would result in inf loss.
+    def test_rnnt_loss_pruned_small_s_range(self):
+        B = 2
+        T = 2
+        S = 10
+        C = 10
+
+        frames = torch.randint(1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for rnnt_type in ["regular"]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    reduction="none",
+                )
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+                for r in range(S0, S + 2):
+                    ranges = fast_rnnt.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        rnnt_type=rnnt_type,
+                        reduction="none",
+                    )
+                    assert (
+                        not pruned_loss.isinf().any()
+                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}"
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
+

 if __name__ == "__main__":
     unittest.main()