OpenDAS / FAST-RNNT — Commits

Unverified Commit b5828e2b
Authored Dec 03, 2021 by Daniel Povey; committed by GitHub, Dec 03, 2021

Merge pull request #1 from danpovey/aux_loss

Add aux version of rnnt loss, allows to have "lm-only" and "am-only" …

Parents: bbd073e4, 58daa40e

Showing 3 changed files with 226 additions and 6 deletions:
torch_mutual_information/__init__.py    +1   −1
torch_mutual_information/rnnt.py        +180 −3
torch_mutual_information/rnnt_test.py   +45  −2
torch_mutual_information/__init__.py

 from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
-from .rnnt import get_rnnt_logprobs, rnnt_loss_simple
+from .rnnt import get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux
torch_mutual_information/rnnt.py

@@ -106,7 +106,7 @@ def rnnt_loss_simple(lm: Tensor,
                      boundary: Tensor = None) -> Tensor:
     """
     A simple case of the RNN-T loss, where the 'joiner' network is just addition.
-    Returns total loss value.
+    Returns negated total loss value.
     Args:
       lm: language-model part of unnormalized log-probs of symbols, with shape
@@ -120,8 +120,185 @@ def rnnt_loss_simple(lm: Tensor,
          if boundary is not supplied.
          Most likely you will want begin_symbol and begin_frame to be zero.
     Returns:
-       a Tensor of shape (B,), containing the total RNN-T loss values
-       for each element of the batch (like log-probs of sequences).
+       a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
+       for each element of the batch (like log-probs of sequences).
     """
     px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
     return mutual_information_recursion(px, py, boundary)
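As the amended docstrings emphasize, the returned values are log-prob-like (negated losses), so a training criterion would negate them. A minimal sketch, assuming lm, am, symbols and termination_symbol are set up as the docstring describes:

    total_loss = -rnnt_loss_simple(lm, am, symbols, termination_symbol,
                                   boundary=None).sum()
    total_loss.backward()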
+def get_rnnt_logprobs_aux(lm: Tensor,
+                          am: Tensor,
+                          symbols: Tensor,
+                          termination_symbol: int,
+                          lm_only_scale: float = 0.1,
+                          am_only_scale: float = 0.1) -> Tuple[Tensor, Tensor]:
+    """
+    Reduces RNN-T problem (the simple case, where joiner network is just
+    addition), to a compact, standard form that can then be given
+    (with boundaries) to mutual_information_recursion().  This version allows
+    you to make the loss-function one of the form:
+          lm_only_scale * lm_probs +
+          am_only_scale * am_probs +
+          (1-lm_only_scale-am_only_scale) * combined_probs
+    where lm_probs and am_probs are the probabilities given the lm and
+    acoustic model independently.  This function is called from
+    rnnt_loss_aux(), but may be useful for other purposes.
+
+    Args:
+      lm: Language model part of un-normalized logprobs of symbols, to be added
+          to acoustic model part before normalizing.  Of shape:
+             [B][S+1][C]
+          where B is the batch size, S is the maximum sequence length of
+          the symbol sequence, possibly including the EOS symbol; and
+          C is size of the symbol vocabulary, including the termination/next-frame
+          symbol.
+          Conceptually, lm[b][s] is a vector of length [C] representing the
+          "language model" part of the un-normalized logprobs of symbols,
+          given all symbols *earlier than* s in the sequence.  The reason
+          we still need this for position S is that we may still be emitting
+          the termination/next-frame symbol at this point.
+      am: Acoustic-model part of un-normalized logprobs of symbols, to be added
+          to language-model part before normalizing.  Of shape:
+             [B][T][C]
+          where B is the batch size, T is the maximum sequence length of
+          the acoustic sequences (in frames); and C is size of the symbol
+          vocabulary, including the termination/next-frame symbol.  It reflects
+          the "acoustic" part of the probability of any given symbol appearing
+          next on this frame.
+      symbols: A LongTensor of shape [B][S], containing the symbols at each
+          position of the sequence, possibly including EOS
+      termination_symbol: The identity of the termination symbol, must be
+          in {0..C-1}
+    Returns: (px, py) (the names are quite arbitrary).
+       px: logprobs, of shape [B][S][T+1]
+       py: logprobs, of shape [B][S+1][T]
+      in the recursion:
+         p[b,0,0] = 0.0
+         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                            p[b,s,t-1] + py[b,s,t-1])
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
+      length s and t respectively.  px[b][s][t] represents the probability of
+      extending the subsequences of length (s,t) by one in the s direction,
+      given the particular symbol, and py[b][s][t] represents the probability
+      of extending the subsequences of length (s,t) by one in the t direction,
+      i.e. of emitting the termination/next-frame symbol.
+      px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
+      we cannot emit any symbols.  This is simply a way of incorporating
+      the probability of the termination symbol on the last frame.
+    """
+    assert (lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0]
+            and lm.shape[2] == am.shape[2])
+    (B, T, C) = am.shape
+    S = lm.shape[1] - 1
+    assert symbols.shape == (B, S)
+
+    # Caution: some parts of this code are a little less clear than they could
+    # be due to optimizations.  In particular it may not be totally obvious that
+    # all of the logprobs here are properly normalized.  We test that
+    # this code is invariant to adding constants in the appropriate ways.
+
+    # subtracting am_max and lm_max is to ensure the probs are in a good range
+    # to do exp() without causing underflow or overflow.
+    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
+    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
+    am_probs = (am - am_max).exp()  # [B][T][C]
+    lm_probs = (lm - lm_max).exp()  # [B][S+1][C]
+    # normalizers: [B][S+1][T]
+    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) +
+                   1.0e-20).log()
+
+    # normalizer per frame, if we take only the LM probs by themselves
+    lmonly_normalizers = lm_probs.sum(dim=2, keepdim=True)
+    # lmonly_normalizers: [B][S+1][1]
+    unigram_lm = torch.mean(lm_probs / lmonly_normalizers,
+                            dim=(0, 1), keepdim=True) + 1.0e-20  # [1][1][C]
+    amonly_normalizers = torch.mv(am_probs.reshape(-1, C),
+                                  unigram_lm.reshape(C)).reshape(B, T, 1).log() + am_max  # [B][T][1]
+    amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
+    unigram_lm = unigram_lm.log()
+    lmonly_normalizers = lmonly_normalizers.log() + lm_max
+    # lmonly_normalizers: [B][S+1][1], log-normalizer, used for LM-only part of prob.
+
+    # add lm_max and am_max to normalizers, to make it as if we had not
+    # subtracted am_max and lm_max above.
+    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]
+    # px is the probs of the actual symbols (not yet normalized)..
+    px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
+                         index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)
+                         ).squeeze(-1)  # [B][S][T]
+    px_am = torch.cat((px_am,
+                       torch.full((B, S, 1), float('-inf'),
+                                  device=px_am.device, dtype=px_am.dtype)),
+                      dim=2)  # now: [B][S][T+1], index [:,:,T] has -inf..
+    px_lm = torch.gather(lm[:, :S], dim=2,
+                         index=symbols.unsqueeze(-1))  # [B][S][1]
+    px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C), dim=2,
+                                 index=symbols.unsqueeze(-1))  # [B][S][1]
+
+    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
+    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1]
+
+    px_amonly = px_am + px_lm_unigram  # [B][S][T+1]
+    px_amonly[:, :, :T] -= amonly_normalizers
+    px_lmonly = px_lm - lmonly_normalizers[:, :S, :]
+
+    # py is the probs of termination symbols, of shape [B][S+1][T]
+    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
+    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
+    py = py_am + py_lm - normalizers
+
+    py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
+    py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
+    py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]
+
+    combined_scale = 1.0 - lm_only_scale - am_only_scale
+
+    # We need to avoid exact zeros in the scales because otherwise multiplying
+    # -inf by zero generates nan.
+    if lm_only_scale == 0.0:
+        lm_only_scale = 1.0e-20
+    if am_only_scale == 0.0:
+        am_only_scale = 1.0e-20
+
+    px_interp = (px * combined_scale + px_lmonly * lm_only_scale +
+                 px_amonly * am_only_scale)
+    py_interp = (py * combined_scale + py_lmonly * lm_only_scale +
+                 py_amonly * am_only_scale)
+    print("px_interp = ", px_interp)
+    print("py_interp = ", py_interp)
+    return (px_interp, py_interp)
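The (px, py) semantics in the docstring above can be sanity-checked with a naive version of the recursion. This is only an illustrative sketch, assuming the default boundary [0, 0, S, T] and that the recursion terminates at p[b, S, T]; the library's mutual_information_recursion() is the actual implementation:

    import torch

    def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
        # px: [B][S][T+1], py: [B][S+1][T]; returns p[:, S, T], of shape [B].
        (B, S, T1) = px.shape
        T = T1 - 1
        p = torch.full((B, S + 1, T + 1), float('-inf'),
                       device=px.device, dtype=px.dtype)
        p[:, 0, 0] = 0.0  # empty symbol prefix, zero frames consumed
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                terms = []
                if s > 0:  # extend by one symbol (s direction)
                    terms.append(p[:, s - 1, t] + px[:, s - 1, t])
                if t > 0:  # emit termination/next-frame symbol (t direction)
                    terms.append(p[:, s, t - 1] + py[:, s, t - 1])
                p[:, s, t] = torch.logsumexp(torch.stack(terms), dim=0)
        return p[:, S, T]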
+def rnnt_loss_aux(lm: Tensor,
+                  am: Tensor,
+                  symbols: Tensor,
+                  termination_symbol: int,
+                  lm_only_scale: float = 0.1,
+                  am_only_scale: float = 0.1,
+                  boundary: Tensor = None) -> Tensor:
+    """
+    A simple case of the RNN-T loss, where the 'joiner' network is just addition.
+    Returns negated total loss value.
+
+    Args:
+      lm: language-model part of unnormalized log-probs of symbols, with shape
+          (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
+          These are assumed to be well-normalized, in the sense that we could
+          use them as probabilities separately from the am scores
+      am: acoustic-model part of unnormalized log-probs of symbols, with shape
+          (B, T, C), i.e. batch, frame, num_classes
+      symbols: the symbol sequences, a LongTensor of shape [B][S], and elements
+          in {0..C-1}.
+      termination_symbol: the termination symbol, with 0 <= termination_symbol < C
+      am_only_scale: the scale on the "AM-only" part of the loss, for which we
+          use an "averaged" LM (averaged over all histories, so effectively
+          unigram).
+      boundary: a LongTensor of shape [B, 4] with elements interpreted as
+          [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+          [0, 0, S, T] if boundary is not supplied.
+          Most likely you will want begin_symbol and begin_frame to be zero.
+    Returns:
+       a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
+       for each element of the batch (like log-probs of sequences).
+    """
+    px, py = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
+                                   lm_only_scale, am_only_scale)
+    return mutual_information_recursion(px, py, boundary)
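For orientation, a hedged usage sketch of the new entry point; the shapes and values below are illustrative, not from the commit. With both auxiliary scales at zero the result should match rnnt_loss_simple, which is what the updated test below checks:

    import torch
    from torch_mutual_information import rnnt_loss_aux

    B, S, T, C = 2, 5, 10, 8                # illustrative sizes
    lm = torch.randn(B, S + 1, C)           # LM logits, [B][S+1][C]
    am = torch.randn(B, T, C)               # AM logits, [B][T][C]
    symbols = torch.randint(0, C, (B, S))   # symbol sequences, [B][S]
    termination_symbol = 0

    # 0.8 * combined + 0.1 * LM-only + 0.1 * AM-only (the defaults).
    m = rnnt_loss_aux(lm, am, symbols, termination_symbol,
                      lm_only_scale=0.1, am_only_scale=0.1, boundary=None)
    loss = -m.sum()  # returned values are negated losses (log-prob-like)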
torch_mutual_information/rnnt_test.py

 import random
 import torch
-from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion, get_rnnt_logprobs, rnnt_loss_simple
+from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion, get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux


 def test_rnnt_logprobs_basic():
@@ -43,14 +43,57 @@ def test_rnnt_logprobs_basic():
     device = torch.device('cuda')
     m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device),
                           termination_symbol, None)
-    print("m3 = ", m2)
+    print("m3 = ", m3)
+    device = torch.device('cuda')
+    m4 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
+                       termination_symbol, lm_only_scale=0.0,
+                       am_only_scale=0.0, boundary=None)
+    print("m4 = ", m4)

     assert torch.allclose(m, m2)
     assert torch.allclose(m, m3.to('cpu'))
+    assert torch.allclose(m, m4.to('cpu'))
+
+
+def test_rnnt_logprobs_aux():
+    print("Running test_rnnt_logprobs_aux()")
+    B = 1
+    S = 3
+    T = 4
+    C = 3
+    # lm: [B][S+1][C]
+    lm = torch.tensor([[[0, 0, 1],
+                        [0, 1, 1],
+                        [1, 0, 1],
+                        [2, 2, 0]]], dtype=torch.float)
+    # am: [B][T][C]
+    am = torch.tensor([[[0, 1, 2],
+                        [0, 0, 0],
+                        [0, 2, 4],
+                        [0, 3, 3]]], dtype=torch.float)
+
+    termination_symbol = 2
+    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)
+
+    device = torch.device('cuda')
+    m1 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
+                       termination_symbol, lm_only_scale=0.0,
+                       am_only_scale=0.333, boundary=None)
+    print("m1 = ", m1)
+
+    # should be invariant to adding a constant for any frame.
+    lm += torch.randn(B, S + 1, 1)
+    am += torch.randn(B, T, 1)
+
+    m2 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
+                       termination_symbol, lm_only_scale=0.0,
+                       am_only_scale=0.333, boundary=None)
+    print("m2 = ", m2)
+    assert torch.allclose(m1, m2)


 if __name__ == "__main__":
     #torch.set_printoptions(edgeitems=30)
+    test_rnnt_logprobs_aux()
     test_rnnt_logprobs_basic()
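Note that both tests hard-code device = torch.device('cuda') for the GPU checks, so running the file directly, e.g.

    python torch_mutual_information/rnnt_test.py

presumably requires a CUDA-capable installation.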