Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
b0ed23ef
"docs/source/vscode:/vscode.git/clone" did not exist on "40afdf7cd5cb332127c28a53e1d729bb4e97b6c7"
Commit
b0ed23ef
authored
Oct 30, 2022
by
pkufool
Browse files
Add constrained rnnt
parent
dc35168d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
292 additions
and
152 deletions
+292
-152
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+254
-122
fast_rnnt/python/tests/rnnt_loss_test.py
fast_rnnt/python/tests/rnnt_loss_test.py
+38
-30
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
b0ed23ef
...
@@ -26,13 +26,13 @@ from .mutual_information import mutual_information_recursion
...
@@ -26,13 +26,13 @@ from .mutual_information import mutual_information_recursion
def
fix_for_boundary
(
px
:
Tensor
,
boundary
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
def
fix_for_boundary
(
px
:
Tensor
,
boundary
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
"""
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and
modified == False
, px[:,:,-1] will
None. If boundary == None and
rnnt_type == "regular"
, px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
to be -infinity.
Args:
Args:
px: a Tensor of shape [B][S][T+1] (this function is only
px: a Tensor of shape [B][S][T+1] (this function is only
called if
modified == False
, see other docs for `
modified
`)
called if
rnnt_type == "regular"
, see other docs for `
rnnt_type
`)
px is modified in-place and returned.
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][4] containing
boundary: None, or a Tensor of shape [B][4] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
[s_begin, t_begin, s_end, t_end]; we need only t_end.
...
@@ -49,8 +49,8 @@ def get_rnnt_logprobs(
...
@@ -49,8 +49,8 @@ def get_rnnt_logprobs(
am
:
Tensor
,
am
:
Tensor
,
symbols
:
Tensor
,
symbols
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
rnnt_type
:
str
=
"regular"
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
)
->
Tuple
[
Tensor
,
Tensor
]:
)
->
Tuple
[
Tensor
,
Tensor
]:
"""
"""
Reduces RNN-T problem (the simple case, where joiner network is just
Reduces RNN-T problem (the simple case, where joiner network is just
...
@@ -97,20 +97,32 @@ def get_rnnt_logprobs(
...
@@ -97,20 +97,32 @@ def get_rnnt_logprobs(
[0, 0, S, T]
[0, 0, S, T]
if boundary is not supplied.
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that takes you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns:
Returns:
(px, py) (the names are quite arbitrary).
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
py: logprobs, of shape [B][S+1][T]
in the recursion::
in the recursion::
p[b,0,0] = 0.0
p[b,0,0] = 0.0
if
!modified
:
if
rnnt_type == "regular"
:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
if
modified
:
if
rnnt_type != "regular"
:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
.. where p[b][s][t] is the "joint score" of the pair of subsequences
...
@@ -121,21 +133,22 @@ def get_rnnt_logprobs(
...
@@ -121,21 +133,22 @@ def get_rnnt_logprobs(
(s,t) by one in the t direction,
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
i.e. of emitting the termination/next-frame symbol.
if
!modified
, px[:,:,T] equals -infinity, meaning on the
if
rnnt_type == "regular"
, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
the probability of the termination symbol on the last frame.
"""
"""
assert
lm
.
ndim
==
3
assert
lm
.
ndim
==
3
,
lm
.
ndim
assert
am
.
ndim
==
3
assert
am
.
ndim
==
3
,
am
.
ndim
assert
lm
.
shape
[
0
]
==
am
.
shape
[
0
]
assert
lm
.
shape
[
0
]
==
am
.
shape
[
0
]
,
(
lm
.
shape
[
0
],
am
.
shape
[
0
])
assert
lm
.
shape
[
2
]
==
am
.
shape
[
2
]
assert
lm
.
shape
[
2
]
==
am
.
shape
[
2
]
,
(
lm
.
shape
[
2
],
am
.
shape
[
2
])
(
B
,
T
,
C
)
=
am
.
shape
(
B
,
T
,
C
)
=
am
.
shape
S
=
lm
.
shape
[
1
]
-
1
S
=
lm
.
shape
[
1
]
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
),
symbols
.
shape
assert
S
>=
1
assert
S
>=
1
,
S
assert
T
>=
S
assert
T
>=
S
,
(
T
,
S
)
assert
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
],
rnnt_type
# subtracting am_max and lm_max is to ensure the probs are in a good range
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
# to do exp() without causing underflow or overflow.
...
@@ -162,7 +175,7 @@ def get_rnnt_logprobs(
...
@@ -162,7 +175,7 @@ def get_rnnt_logprobs(
-
1
-
1
)
# [B][S][T]
)
# [B][S][T]
if
not
modified
:
if
rnnt_type
==
"regular"
:
px_am
=
torch
.
cat
(
px_am
=
torch
.
cat
(
(
(
px_am
,
px_am
,
...
@@ -189,8 +202,10 @@ def get_rnnt_logprobs(
...
@@ -189,8 +202,10 @@ def get_rnnt_logprobs(
py_lm
=
lm
[:,
:,
termination_symbol
].
unsqueeze
(
2
)
# [B][S+1][1]
py_lm
=
lm
[:,
:,
termination_symbol
].
unsqueeze
(
2
)
# [B][S+1][1]
py
=
py_am
+
py_lm
-
normalizers
py
=
py_am
+
py_lm
-
normalizers
if
not
modified
:
if
rnnt_type
==
"regular"
:
px
=
fix_for_boundary
(
px
,
boundary
)
px
=
fix_for_boundary
(
px
,
boundary
)
elif
rnnt_type
==
"constrained"
:
px
+=
py
[:,
1
:,
:]
return
(
px
,
py
)
return
(
px
,
py
)
...
@@ -201,7 +216,7 @@ def rnnt_loss_simple(
...
@@ -201,7 +216,7 @@ def rnnt_loss_simple(
symbols
:
Tensor
,
symbols
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
delay_penalty
:
float
=
0.0
,
delay_penalty
:
float
=
0.0
,
reduction
:
Optional
[
str
]
=
"mean"
,
reduction
:
Optional
[
str
]
=
"mean"
,
return_grad
:
bool
=
False
,
return_grad
:
bool
=
False
,
...
@@ -227,8 +242,19 @@ def rnnt_loss_simple(
...
@@ -227,8 +242,19 @@ def rnnt_loss_simple(
[0, 0, S, T]
[0, 0, S, T]
if boundary is not supplied.
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that takes you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay, this may be
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
encouraging the network to delay symbols.
...
@@ -260,12 +286,12 @@ def rnnt_loss_simple(
...
@@ -260,12 +286,12 @@ def rnnt_loss_simple(
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
if
delay_penalty
>
0.0
:
if
delay_penalty
>
0.0
:
B
,
S
,
T0
=
px
.
shape
B
,
S
,
T0
=
px
.
shape
T
=
T0
if
modified
else
T0
-
1
T
=
T0
if
rnnt_type
!=
"regular"
else
T0
-
1
if
boundary
is
None
:
if
boundary
is
None
:
offset
=
torch
.
tensor
(
offset
=
torch
.
tensor
(
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
...
@@ -289,9 +315,9 @@ def rnnt_loss_simple(
...
@@ -289,9 +315,9 @@ def rnnt_loss_simple(
elif
reduction
==
"sum"
:
elif
reduction
==
"sum"
:
loss
=
-
torch
.
sum
(
negated_loss
)
loss
=
-
torch
.
sum
(
negated_loss
)
else
:
else
:
assert
(
raise
ValueError
(
False
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
,
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
return
(
loss
,
scores_and_grads
[
1
])
if
return_grad
else
loss
return
(
loss
,
scores_and_grads
[
1
])
if
return_grad
else
loss
...
@@ -300,7 +326,7 @@ def get_rnnt_logprobs_joint(
...
@@ -300,7 +326,7 @@ def get_rnnt_logprobs_joint(
symbols
:
Tensor
,
symbols
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
)
->
Tuple
[
Tensor
,
Tensor
]:
)
->
Tuple
[
Tensor
,
Tensor
]:
"""Reduces RNN-T problem to a compact, standard form that can then be given
"""Reduces RNN-T problem to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
(with boundaries) to mutual_information_recursion().
...
@@ -321,21 +347,33 @@ def get_rnnt_logprobs_joint(
...
@@ -321,21 +347,33 @@ def get_rnnt_logprobs_joint(
[0, 0, S, T]
[0, 0, S, T]
if boundary is not supplied.
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that takes you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns:
Returns:
(px, py) (the names are quite arbitrary)::
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1]
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
py: logprobs, of shape [B][S+1][T]
in the recursion::
in the recursion::
p[b,0,0] = 0.0
p[b,0,0] = 0.0
if
!modified
:
if
rnnt_type == "regular"
:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
if
modified
:
if
rnnt_type != "regular"
:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
...
@@ -345,17 +383,18 @@ def get_rnnt_logprobs_joint(
...
@@ -345,17 +383,18 @@ def get_rnnt_logprobs_joint(
of extending the subsequences of length (s,t) by one in the t direction,
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
i.e. of emitting the termination/next-frame symbol.
if
!modified
, px[:,:,T] equals -infinity, meaning on the
if
rnnt_type == "regular"
, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
the probability of the termination symbol on the last frame.
"""
"""
assert
logits
.
ndim
==
4
assert
logits
.
ndim
==
4
,
logits
.
ndim
(
B
,
T
,
S1
,
C
)
=
logits
.
shape
(
B
,
T
,
S1
,
C
)
=
logits
.
shape
S
=
S1
-
1
S
=
S1
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
),
symbols
.
shape
assert
S
>=
1
assert
S
>=
1
,
S
assert
T
>=
S
assert
T
>=
S
,
(
T
,
S
)
assert
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
],
rnnt_type
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
normalizers
.
permute
((
0
,
2
,
1
))
normalizers
=
normalizers
.
permute
((
0
,
2
,
1
))
...
@@ -365,7 +404,7 @@ def get_rnnt_logprobs_joint(
...
@@ -365,7 +404,7 @@ def get_rnnt_logprobs_joint(
).
squeeze
(
-
1
)
).
squeeze
(
-
1
)
px
=
px
.
permute
((
0
,
2
,
1
))
px
=
px
.
permute
((
0
,
2
,
1
))
if
not
modified
:
if
rnnt_type
==
"regular"
:
px
=
torch
.
cat
(
px
=
torch
.
cat
(
(
(
px
,
px
,
...
@@ -383,8 +422,10 @@ def get_rnnt_logprobs_joint(
...
@@ -383,8 +422,10 @@ def get_rnnt_logprobs_joint(
)
# [B][S+1][T]
)
# [B][S+1][T]
py
-=
normalizers
py
-=
normalizers
if
not
modified
:
if
rnnt_type
==
"regular"
:
px
=
fix_for_boundary
(
px
,
boundary
)
px
=
fix_for_boundary
(
px
,
boundary
)
elif
rnnt_type
==
"constrained"
:
px
+=
py
[:,
1
:,
:]
return
(
px
,
py
)
return
(
px
,
py
)
...
@@ -394,7 +435,7 @@ def rnnt_loss(
...
@@ -394,7 +435,7 @@ def rnnt_loss(
symbols
:
Tensor
,
symbols
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
delay_penalty
:
float
=
0.0
,
delay_penalty
:
float
=
0.0
,
reduction
:
Optional
[
str
]
=
"mean"
,
reduction
:
Optional
[
str
]
=
"mean"
,
)
->
Tensor
:
)
->
Tensor
:
...
@@ -415,8 +456,19 @@ def rnnt_loss(
...
@@ -415,8 +456,19 @@ def rnnt_loss(
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that takes you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay, this may be
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
encouraging the network to delay symbols.
...
@@ -438,11 +490,12 @@ def rnnt_loss(
...
@@ -438,11 +490,12 @@ def rnnt_loss(
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
if
delay_penalty
>
0.0
:
if
delay_penalty
>
0.0
:
B
,
S
,
T0
=
px
.
shape
B
,
S
,
T0
=
px
.
shape
T
=
T0
if
modified
else
T0
-
1
T
=
T0
if
rnnt_type
!=
"regular"
else
T0
-
1
if
boundary
is
None
:
if
boundary
is
None
:
offset
=
torch
.
tensor
(
offset
=
torch
.
tensor
(
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
...
@@ -454,6 +507,7 @@ def rnnt_loss(
...
@@ -454,6 +507,7 @@ def rnnt_loss(
).
reshape
(
1
,
1
,
T0
)
).
reshape
(
1
,
1
,
T0
)
penalty
=
penalty
*
delay_penalty
penalty
=
penalty
*
delay_penalty
px
+=
penalty
.
to
(
px
.
dtype
)
px
+=
penalty
.
to
(
px
.
dtype
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
if
reduction
==
"none"
:
if
reduction
==
"none"
:
return
-
negated_loss
return
-
negated_loss
...
@@ -462,30 +516,30 @@ def rnnt_loss(
...
@@ -462,30 +516,30 @@ def rnnt_loss(
elif
reduction
==
"sum"
:
elif
reduction
==
"sum"
:
return
-
torch
.
sum
(
negated_loss
)
return
-
torch
.
sum
(
negated_loss
)
else
:
else
:
assert
(
raise
ValueError
(
False
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
,
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
def
_adjust_pruning_lower_bound
(
def
_adjust_pruning_lower_bound
(
s_begin
:
torch
.
Tensor
,
s_range
:
int
s_begin
:
torch
.
Tensor
,
s_range
:
int
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Adjust s_begin (pruning lower bound) to make it satisf
ied
the following
"""Adjust s_begin (pruning lower bound
s
) to make it satisf
y
the following
constrains
constrain
t
s
- monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
- monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
- start with symbol 0 at first frame.
- start with symbol 0 at first frame.
- s_begin[i + 1] - s_begin[i] < s_range, whic
n
means that we can't skip
- s_begin[i + 1] - s_begin[i] < s_range, whic
h
means that we can't skip
any symbols.
any symbols.
To make it monotonic increasing, we can use `monotonic_lower_bound` function
To make it monotonic increasing, we can use `monotonic_lower_bound` function
in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
in k2, which guarantee
s
`s_begin[i] <= s_begin[i + 1]`. The main idea is:
traverse the array in reverse order and update the elements by
traverse the array in reverse order and update the elements by
`min_value = min(s_begin[i], min_value)`, the initial `min_value` set to
`min_value = min(s_begin[i], min_value)`, the initial `min_value`
is
set to
`inf`.
`inf`.
The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
constrain is a little tricky. We first transform `s_begin` with
constrain
t
is a little tricky. We first transform `s_begin` with
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
then we make the transformed `s_begin` monotonic increasing, after that,
then we make the transformed `s_begin` monotonic increasing, after that,
we transform back `s_begin` with the same formula as the previous
we transform back `s_begin` with the same formula as the previous
...
@@ -551,9 +605,9 @@ def get_rnnt_prune_ranges(
...
@@ -551,9 +605,9 @@ def get_rnnt_prune_ranges(
Note:
Note:
For the generated tensor ranges(assuming batch size is 1), ranges[:, 0]
For the generated tensor ranges(assuming batch size is 1), ranges[:, 0]
is a monotonic increasing tensor from 0 to `len(symbols)
` and it satisfies
is a monotonic increasing tensor from 0 to `len(symbols)
- s_range` and
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we
won't skip any
it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we
symbols.
won't skip any
symbols.
Args:
Args:
px_grad:
px_grad:
...
@@ -568,21 +622,21 @@ def get_rnnt_prune_ranges(
...
@@ -568,21 +622,21 @@ def get_rnnt_prune_ranges(
s_range:
s_range:
How many symbols to keep for each frame.
How many symbols to keep for each frame.
Returns:
Returns:
A tensor contain
s
the
kept symbols indexes for each frame, with shap
e
A tensor
with the shape of (B, T, s_range)
contain
ing
the
indexes of th
e
(B, T, s_range)
.
kept symbols for each frame
.
"""
"""
(
B
,
S
,
T1
)
=
px_grad
.
shape
(
B
,
S
,
T1
)
=
px_grad
.
shape
T
=
py_grad
.
shape
[
-
1
]
T
=
py_grad
.
shape
[
-
1
]
assert
T1
in
[
T
,
T
+
1
]
assert
T1
in
[
T
,
T
+
1
]
,
T1
S1
=
S
+
1
S1
=
S
+
1
assert
py_grad
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py_grad
.
shape
==
(
B
,
S
+
1
,
T
)
,
py_grad
.
shape
assert
boundary
.
shape
==
(
B
,
4
)
assert
boundary
.
shape
==
(
B
,
4
)
,
boundary
.
shape
assert
S
>=
1
assert
S
>=
1
,
S
assert
T
>=
S
assert
T
>=
S
,
(
T
,
S
)
# s_range > S means we won't prune out any symbols. To make indexing with
# s_range > S means we won't prune out any symbols. To make indexing with
# ranges run
s
normally, s_range should be equal to or less than ``S + 1``.
# ranges run normally, s_range should be equal to or less than ``S + 1``.
if
s_range
>
S
:
if
s_range
>
S
:
s_range
=
S
+
1
s_range
=
S
+
1
...
@@ -630,16 +684,17 @@ def get_rnnt_prune_ranges(
...
@@ -630,16 +684,17 @@ def get_rnnt_prune_ranges(
mask
=
mask
<
boundary
[:,
3
].
reshape
(
B
,
1
)
-
1
mask
=
mask
<
boundary
[:,
3
].
reshape
(
B
,
1
)
-
1
s_begin_padding
=
boundary
[:,
2
].
reshape
(
B
,
1
)
-
s_range
+
1
s_begin_padding
=
boundary
[:,
2
].
reshape
(
B
,
1
)
-
s_range
+
1
# handle the cases whe
n
`len(symbols) < s_range`
# handle the cases whe
re
`len(symbols) < s_range`
s_begin_padding
=
torch
.
clamp
(
s_begin_padding
,
min
=
0
)
s_begin_padding
=
torch
.
clamp
(
s_begin_padding
,
min
=
0
)
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
# adjusting lower bound to make it satisfied some constrains, see docs in
# adjusting lower bound to make it satisfy some constraints, see docs in
# `_adjust_pruning_lower_bound` for more details of these constrains.
# `_adjust_pruning_lower_bound` for more details of these constraints.
# T1 == T here means we are using the modified version of transducer,
# T1 == T here means we are using the non-regular(i.e. modified rnnt or
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# constrained rnnt) version of transducer, the third constraint becomes
# it only emits one symbol per frame.
# `s_begin[i + 1] - s_begin[i] < 2`, because it only emits one symbol per
# frame.
s_begin
=
_adjust_pruning_lower_bound
(
s_begin
,
2
if
T1
==
T
else
s_range
)
s_begin
=
_adjust_pruning_lower_bound
(
s_begin
,
2
if
T1
==
T
else
s_range
)
ranges
=
s_begin
.
reshape
((
B
,
T
,
1
)).
expand
((
B
,
T
,
s_range
))
+
torch
.
arange
(
ranges
=
s_begin
.
reshape
((
B
,
T
,
1
)).
expand
((
B
,
T
,
s_range
))
+
torch
.
arange
(
...
@@ -652,8 +707,8 @@ def get_rnnt_prune_ranges(
...
@@ -652,8 +707,8 @@ def get_rnnt_prune_ranges(
def
do_rnnt_pruning
(
def
do_rnnt_pruning
(
am
:
torch
.
Tensor
,
lm
:
torch
.
Tensor
,
ranges
:
torch
.
Tensor
am
:
torch
.
Tensor
,
lm
:
torch
.
Tensor
,
ranges
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Prune the output of encoder(am)
output
and prediction network(lm)
"""Prune the output of encoder(am) and prediction network(lm)
with ranges
output of RNNT
.
generated by `get_rnnt_prune_ranges`
.
Args:
Args:
am:
am:
...
@@ -671,9 +726,9 @@ def do_rnnt_pruning(
...
@@ -671,9 +726,9 @@ def do_rnnt_pruning(
# am (B, T, C)
# am (B, T, C)
# lm (B, S + 1, C)
# lm (B, S + 1, C)
# ranges (B, T, s_range)
# ranges (B, T, s_range)
assert
ranges
.
shape
[
0
]
==
am
.
shape
[
0
]
assert
ranges
.
shape
[
0
]
==
am
.
shape
[
0
]
,
(
ranges
.
shape
[
0
],
am
.
shape
[
0
])
assert
ranges
.
shape
[
0
]
==
lm
.
shape
[
0
]
assert
ranges
.
shape
[
0
]
==
lm
.
shape
[
0
]
,
(
ranges
.
shape
[
0
],
lm
.
shape
[
0
])
assert
am
.
shape
[
1
]
==
ranges
.
shape
[
1
]
assert
am
.
shape
[
1
]
==
ranges
.
shape
[
1
]
,
(
am
.
shape
[
1
],
ranges
.
shape
[
1
])
(
B
,
T
,
s_range
)
=
ranges
.
shape
(
B
,
T
,
s_range
)
=
ranges
.
shape
(
B
,
S1
,
C
)
=
lm
.
shape
(
B
,
S1
,
C
)
=
lm
.
shape
S
=
S1
-
1
S
=
S1
-
1
...
@@ -711,9 +766,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
...
@@ -711,9 +766,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
[ 8, 9, 5, 6, 7],
[ 8, 9, 5, 6, 7],
[12, 13, 14, 10, 11]]])
[12, 13, 14, 10, 11]]])
"""
"""
assert
src
.
dim
()
==
3
assert
src
.
dim
()
==
3
,
src
.
dim
()
(
B
,
T
,
S
)
=
src
.
shape
(
B
,
T
,
S
)
=
src
.
shape
assert
shifts
.
shape
==
(
B
,
T
)
assert
shifts
.
shape
==
(
B
,
T
)
,
shifts
.
shape
index
=
(
index
=
(
torch
.
arange
(
S
,
device
=
src
.
device
)
torch
.
arange
(
S
,
device
=
src
.
device
)
...
@@ -731,7 +786,7 @@ def get_rnnt_logprobs_pruned(
...
@@ -731,7 +786,7 @@ def get_rnnt_logprobs_pruned(
ranges
:
Tensor
,
ranges
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Tensor
,
boundary
:
Tensor
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
)
->
Tuple
[
Tensor
,
Tensor
]:
)
->
Tuple
[
Tensor
,
Tensor
]:
"""Construct px, py for mutual_information_recursion with pruned output.
"""Construct px, py for mutual_information_recursion with pruned output.
...
@@ -751,21 +806,53 @@ def get_rnnt_logprobs_pruned(
...
@@ -751,21 +806,53 @@ def get_rnnt_logprobs_pruned(
[0, 0, S, T]
[0, 0, S, T]
if boundary is not supplied.
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that takes you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns:
Returns:
Return the px (B, S, T) if modified else (B, S, T + 1) and
(px, py) (the names are quite arbitrary)::
py (B, S + 1, T) needed by mutual_information_recursion.
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
"""
# logits (B, T, s_range, C)
# logits (B, T, s_range, C)
# symbols (B, S)
# symbols (B, S)
# ranges (B, T, s_range)
# ranges (B, T, s_range)
assert
logits
.
ndim
==
4
assert
logits
.
ndim
==
4
,
logits
.
ndim
(
B
,
T
,
s_range
,
C
)
=
logits
.
shape
(
B
,
T
,
s_range
,
C
)
=
logits
.
shape
assert
ranges
.
shape
==
(
B
,
T
,
s_range
)
assert
ranges
.
shape
==
(
B
,
T
,
s_range
)
,
ranges
.
shape
(
B
,
S
)
=
symbols
.
shape
(
B
,
S
)
=
symbols
.
shape
assert
S
>=
1
assert
S
>=
1
,
S
assert
T
>=
S
assert
T
>=
S
,
(
T
,
S
)
assert
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
],
rnnt_type
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
...
@@ -813,7 +900,7 @@ def get_rnnt_logprobs_pruned(
...
@@ -813,7 +900,7 @@ def get_rnnt_logprobs_pruned(
px
=
px
.
permute
((
0
,
2
,
1
))
px
=
px
.
permute
((
0
,
2
,
1
))
if
not
modified
:
if
rnnt_type
==
"regular"
:
px
=
torch
.
cat
(
px
=
torch
.
cat
(
(
(
px
,
px
,
...
@@ -846,8 +933,10 @@ def get_rnnt_logprobs_pruned(
...
@@ -846,8 +933,10 @@ def get_rnnt_logprobs_pruned(
# (B, S + 1, T)
# (B, S + 1, T)
py
=
py
.
permute
((
0
,
2
,
1
))
py
=
py
.
permute
((
0
,
2
,
1
))
if
not
modified
:
if
rnnt_type
==
"regular"
:
px
=
fix_for_boundary
(
px
,
boundary
)
px
=
fix_for_boundary
(
px
,
boundary
)
elif
rnnt_type
==
"constrained"
:
px
+=
py
[:,
1
:,
:]
return
(
px
,
py
)
return
(
px
,
py
)
...
@@ -858,13 +947,13 @@ def rnnt_loss_pruned(
...
@@ -858,13 +947,13 @@ def rnnt_loss_pruned(
ranges
:
Tensor
,
ranges
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Tensor
=
None
,
boundary
:
Tensor
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
delay_penalty
:
float
=
0.0
,
delay_penalty
:
float
=
0.0
,
reduction
:
Optional
[
str
]
=
"mean"
,
reduction
:
Optional
[
str
]
=
"mean"
,
)
->
Tensor
:
)
->
Tensor
:
"""A RNN-T loss with pruning, which uses a pruned 'joiner'
network output
"""A RNN-T loss with pruning, which uses
the output of
a pruned 'joiner'
as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C),
network
as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C),
s_range means the symbols
number
kept for each frame.
s_range means the
number of
symbols kept for each frame.
Args:
Args:
logits:
logits:
...
@@ -882,8 +971,19 @@ def rnnt_loss_pruned(
...
@@ -882,8 +971,19 @@ def rnnt_loss_pruned(
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, which takes you to the next frame only when
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that takes you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that goes to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay, this may be
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
encouraging the network to delay symbols.
...
@@ -895,8 +995,8 @@ def rnnt_loss_pruned(
...
@@ -895,8 +995,8 @@ def rnnt_loss_pruned(
`sum`: the output will be summed.
`sum`: the output will be summed.
Default: `mean`
Default: `mean`
Returns:
Returns:
If re
curs
ion is `none`, returns a tensor of shape (B,), containing the
If re
duct
ion is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each
element
of the batch, otherwise a scalar
total RNN-T loss values for each
sequence
of the batch, otherwise a scalar
with the reduction applied.
with the reduction applied.
"""
"""
px
,
py
=
get_rnnt_logprobs_pruned
(
px
,
py
=
get_rnnt_logprobs_pruned
(
...
@@ -905,11 +1005,12 @@ def rnnt_loss_pruned(
...
@@ -905,11 +1005,12 @@ def rnnt_loss_pruned(
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
if
delay_penalty
>
0.0
:
if
delay_penalty
>
0.0
:
B
,
S
,
T0
=
px
.
shape
B
,
S
,
T0
=
px
.
shape
T
=
T0
if
modified
else
T0
-
1
T
=
T0
if
rnnt_type
!=
"regular"
else
T0
-
1
if
boundary
is
None
:
if
boundary
is
None
:
offset
=
torch
.
tensor
(
offset
=
torch
.
tensor
(
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
...
@@ -921,6 +1022,7 @@ def rnnt_loss_pruned(
...
@@ -921,6 +1022,7 @@ def rnnt_loss_pruned(
).
reshape
(
1
,
1
,
T0
)
).
reshape
(
1
,
1
,
T0
)
penalty
=
penalty
*
delay_penalty
penalty
=
penalty
*
delay_penalty
px
+=
penalty
.
to
(
px
.
dtype
)
px
+=
penalty
.
to
(
px
.
dtype
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
if
reduction
==
"none"
:
if
reduction
==
"none"
:
return
-
negated_loss
return
-
negated_loss
...
@@ -929,9 +1031,9 @@ def rnnt_loss_pruned(
...
@@ -929,9 +1031,9 @@ def rnnt_loss_pruned(
elif
reduction
==
"sum"
:
elif
reduction
==
"sum"
:
return
-
torch
.
sum
(
negated_loss
)
return
-
torch
.
sum
(
negated_loss
)
else
:
else
:
assert
(
raise
ValueError
(
False
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
,
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
def
get_rnnt_logprobs_smoothed
(
def
get_rnnt_logprobs_smoothed
(
...
@@ -942,7 +1044,7 @@ def get_rnnt_logprobs_smoothed(
...
@@ -942,7 +1044,7 @@ def get_rnnt_logprobs_smoothed(
lm_only_scale
:
float
=
0.1
,
lm_only_scale
:
float
=
0.1
,
am_only_scale
:
float
=
0.1
,
am_only_scale
:
float
=
0.1
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
)
->
Tuple
[
Tensor
,
Tensor
]:
)
->
Tuple
[
Tensor
,
Tensor
]:
"""
"""
Reduces RNN-T problem (the simple case, where joiner network is just
Reduces RNN-T problem (the simple case, where joiner network is just
...
@@ -1005,18 +1107,32 @@ def get_rnnt_logprobs_smoothed(
...
@@ -1005,18 +1107,32 @@ def get_rnnt_logprobs_smoothed(
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
also be consumed, so at most 1 symbol can appear per frame.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, which takes you to the next frame only when
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that takes you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that goes to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns:
Returns:
(px, py) (the names are quite arbitrary).
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
px: logprobs, of shape [B][S][T+1] if rnnt_type == "regular",
[B][S][T] if rnnt_type != "regular".
py: logprobs, of shape [B][S+1][T]
py: logprobs, of shape [B][S+1][T]
in the recursion::
in the recursion::
p[b,0,0] = 0.0
p[b,0,0] = 0.0
if
!modified
:
if
rnnt_type == "regular"
:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
if
modified
:
if
rnnt_type != "regular"
:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
.. where p[b][s][t] is the "joint score" of the pair of subsequences
...
@@ -1031,15 +1147,16 @@ def get_rnnt_logprobs_smoothed(
...
@@ -1031,15 +1147,16 @@ def get_rnnt_logprobs_smoothed(
we cannot emit any symbols. This is simply a way of incorporating
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
the probability of the termination symbol on the last frame.
"""
"""
assert
lm
.
ndim
==
3
assert
lm
.
ndim
==
3
,
lm
.
ndim
assert
am
.
ndim
==
3
assert
am
.
ndim
==
3
,
am
.
ndim
assert
lm
.
shape
[
0
]
==
am
.
shape
[
0
]
assert
lm
.
shape
[
0
]
==
am
.
shape
[
0
]
,
(
lm
.
shape
[
0
],
am
.
shape
[
0
])
assert
lm
.
shape
[
2
]
==
am
.
shape
[
2
]
assert
lm
.
shape
[
2
]
==
am
.
shape
[
2
]
,
(
lm
.
shape
[
2
],
am
.
shape
[
2
])
(
B
,
T
,
C
)
=
am
.
shape
(
B
,
T
,
C
)
=
am
.
shape
S
=
lm
.
shape
[
1
]
-
1
S
=
lm
.
shape
[
1
]
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
),
symbols
.
shape
assert
S
>=
1
assert
S
>=
1
,
S
assert
T
>=
S
assert
T
>=
S
,
(
T
,
S
)
assert
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
],
rnnt_type
# Caution: some parts of this code are a little less clear than they could
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# be due to optimizations. In particular it may not be totally obvious that
...
@@ -1091,7 +1208,7 @@ def get_rnnt_logprobs_smoothed(
...
@@ -1091,7 +1208,7 @@ def get_rnnt_logprobs_smoothed(
-
1
-
1
)
# [B][S][T]
)
# [B][S][T]
if
not
modified
:
if
rnnt_type
==
"regular"
:
px_am
=
torch
.
cat
(
px_am
=
torch
.
cat
(
(
(
px_am
,
px_am
,
...
@@ -1150,8 +1267,10 @@ def get_rnnt_logprobs_smoothed(
...
@@ -1150,8 +1267,10 @@ def get_rnnt_logprobs_smoothed(
+
py_amonly
*
am_only_scale
+
py_amonly
*
am_only_scale
)
)
if
not
modified
:
if
rnnt_type
==
"regular"
:
px_interp
=
fix_for_boundary
(
px_interp
,
boundary
)
px_interp
=
fix_for_boundary
(
px_interp
,
boundary
)
elif
rnnt_type
==
"constrained"
:
px_interp
+=
py_interp
[:,
1
:,
:]
return
(
px_interp
,
py_interp
)
return
(
px_interp
,
py_interp
)
...
@@ -1164,7 +1283,7 @@ def rnnt_loss_smoothed(
...
@@ -1164,7 +1283,7 @@ def rnnt_loss_smoothed(
lm_only_scale
:
float
=
0.1
,
lm_only_scale
:
float
=
0.1
,
am_only_scale
:
float
=
0.1
,
am_only_scale
:
float
=
0.1
,
boundary
:
Optional
[
Tensor
]
=
None
,
boundary
:
Optional
[
Tensor
]
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
delay_penalty
:
float
=
0.0
,
delay_penalty
:
float
=
0.0
,
reduction
:
Optional
[
str
]
=
"mean"
,
reduction
:
Optional
[
str
]
=
"mean"
,
return_grad
:
bool
=
False
,
return_grad
:
bool
=
False
,
...
@@ -1197,8 +1316,19 @@ def rnnt_loss_smoothed(
...
@@ -1197,8 +1316,19 @@ def rnnt_loss_smoothed(
[0, 0, S, T]
[0, 0, S, T]
if boundary is not supplied.
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, which takes you to the next frame only when
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that takes you to the next
frame when emitting either a blank or a non-blank symbol.
`constrained`: A version like the modified one that goes to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay, this may be
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
encouraging the network to delay symbols.
...
@@ -1233,11 +1363,12 @@ def rnnt_loss_smoothed(
...
@@ -1233,11 +1363,12 @@ def rnnt_loss_smoothed(
lm_only_scale
=
lm_only_scale
,
lm_only_scale
=
lm_only_scale
,
am_only_scale
=
am_only_scale
,
am_only_scale
=
am_only_scale
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
if
delay_penalty
>
0.0
:
if
delay_penalty
>
0.0
:
B
,
S
,
T0
=
px
.
shape
B
,
S
,
T0
=
px
.
shape
T
=
T0
if
modified
else
T0
-
1
T
=
T0
if
rnnt_type
!=
"regular"
else
T0
-
1
if
boundary
is
None
:
if
boundary
is
None
:
offset
=
torch
.
tensor
(
offset
=
torch
.
tensor
(
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
...
@@ -1249,6 +1380,7 @@ def rnnt_loss_smoothed(
...
@@ -1249,6 +1380,7 @@ def rnnt_loss_smoothed(
).
reshape
(
1
,
1
,
T0
)
).
reshape
(
1
,
1
,
T0
)
penalty
=
penalty
*
delay_penalty
penalty
=
penalty
*
delay_penalty
px
+=
penalty
.
to
(
px
.
dtype
)
px
+=
penalty
.
to
(
px
.
dtype
)
scores_and_grads
=
mutual_information_recursion
(
scores_and_grads
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
,
return_grad
=
return_grad
px
=
px
,
py
=
py
,
boundary
=
boundary
,
return_grad
=
return_grad
)
)
...
@@ -1260,7 +1392,7 @@ def rnnt_loss_smoothed(
...
@@ -1260,7 +1392,7 @@ def rnnt_loss_smoothed(
elif
reduction
==
"sum"
:
elif
reduction
==
"sum"
:
loss
=
-
torch
.
sum
(
negated_loss
)
loss
=
-
torch
.
sum
(
negated_loss
)
else
:
else
:
assert
(
raise
ValueError
(
False
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
,
f
"reduction should be ('none' | 'mean' | 'sum'), given
{
reduction
}
"
)
return
(
loss
,
scores_and_grads
[
1
])
if
return_grad
else
loss
return
(
loss
,
scores_and_grads
[
1
])
if
return_grad
else
loss
fast_rnnt/python/tests/rnnt_loss_test.py
View file @
b0ed23ef
...
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert
px
.
shape
==
(
B
,
S
,
T
+
1
)
assert
px
.
shape
==
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
None
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
None
)
if
device
==
torch
.
device
(
"cpu"
):
if
device
==
torch
.
device
(
"cpu"
):
expected
=
-
m
expected
=
-
m
...
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
# lm: [B][S+1][C]
lm
=
lm_
.
to
(
device
)
lm
=
lm_
.
to
(
device
)
...
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
assert
(
px
.
shape
==
(
B
,
S
,
T
)
if
rnnt_type
!=
"regular"
else
(
B
,
S
,
T
+
1
)
)
)
assert
px
.
shape
==
(
B
,
S
,
T
)
if
modified
else
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
m
=
fast_rnnt
.
mutual_information_recursion
(
...
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale
=
0.0
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# compare with torchaudio rnnt_loss
# compare with torchaudio rnnt_loss
if
self
.
has_torch_rnnt_loss
and
not
modified
:
if
self
.
has_torch_rnnt_loss
and
rnnt_type
==
"regular"
:
import
torchaudio.functional
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
m
=
torchaudio
.
functional
.
rnnt_loss
(
...
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale
=
0.0
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
logits2
)
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
logits2
)
torch_grad
=
torch_grad
[
0
]
torch_grad
=
torch_grad
[
0
]
assert
torch
.
allclose
(
fast_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_rnnt_loss_smoothed
(
self
):
def
test_rnnt_loss_smoothed
(
self
):
B
=
1
B
=
1
...
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# normal rnnt
# normal rnnt
am
=
am_
.
to
(
device
)
am
=
am_
.
to
(
device
)
...
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
print
(
print
(
f
"Unpruned rnnt loss with
{
rnnt_loss
}
rnnt :
{
fast_loss
}
"
)
f
"Unpruned rnnt loss with modified
{
modified
}
:
{
fast_loss
}
"
)
# pruning
# pruning
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
...
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
return_grad
=
True
,
return_grad
=
True
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
...
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range
=
r
,
s_range
=
r
,
)
)
# (B, T, r, C)
# (B, T, r, C)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
logits
=
pruned_am
+
pruned_lm
logits
=
pruned_am
+
pruned_lm
...
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
f
"Pruning loss with range
{
r
}
:
{
pruned_loss
}
"
)
print
(
f
"Pruning loss with range
{
r
}
:
{
pruned_loss
}
"
)
# Test the sequences that only have small number of symbols,
# Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will
# at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions.
# raise errors (like, nan or inf loss) in our previous versions.
...
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print
(
f
"B =
{
B
}
, T =
{
T
}
, S =
{
S
}
, C =
{
C
}
"
)
print
(
f
"B =
{
B
}
, T =
{
T
}
, S =
{
S
}
, C =
{
C
}
"
)
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# normal rnnt
# normal rnnt
am
=
am_
.
to
(
device
)
am
=
am_
.
to
(
device
)
...
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
print
(
f
"Unpruned rnnt loss with
{
rnnt_type
}
rnnt :
{
loss
}
"
)
f
"Unpruned rnnt loss with modified
{
modified
}
:
{
loss
}
"
)
# pruning
# pruning
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
...
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
return_grad
=
True
,
return_grad
=
True
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
S0
=
2
S0
=
2
if
modified
:
if
rnnt_type
==
"regular"
:
S0
=
1
S0
=
1
for
r
in
range
(
S0
,
S
+
2
):
for
r
in
range
(
S0
,
S
+
2
):
...
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
f
"Pruned loss with range
{
r
}
:
{
pruned_loss
}
"
)
print
(
f
"Pruned loss with range
{
r
}
:
{
pruned_loss
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment