Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
dc35168d
Commit
dc35168d
authored
Oct 30, 2022
by
pkufool
Browse files
Add delay penalty
parent
23b841cf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
0 deletions
+79
-0
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+79
-0
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
dc35168d
...
...
@@ -202,6 +202,7 @@ def rnnt_loss_simple(
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
...
...
@@ -228,6 +229,10 @@ def rnnt_loss_simple(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...
...
@@ -257,6 +262,22 @@ def rnnt_loss_simple(
        boundary=boundary,
        modified=modified,
    )
    if delay_penalty > 0.0:
        B, S, T0 = px.shape
        T = T0 if modified else T0 - 1
        if boundary is None:
            offset = torch.tensor(
                (T - 1) / 2,
                dtype=px.dtype,
                device=px.device,
            ).expand(B, 1, 1)
        else:
            offset = (boundary[:, 3] - 1) / 2
        penalty = offset.reshape(B, 1, 1) - torch.arange(
            T0, device=px.device
        ).reshape(1, 1, T0)
        penalty = penalty * delay_penalty
        px += penalty.to(px.dtype)
    scores_and_grads = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
...
...
@@ -374,6 +395,7 @@ def rnnt_loss(
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
...
...
@@ -395,6 +417,10 @@ def rnnt_loss(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...
...
@@ -414,6 +440,20 @@ def rnnt_loss(
        boundary=boundary,
        modified=modified,
    )
    if delay_penalty > 0.0:
        B, S, T0 = px.shape
        T = T0 if modified else T0 - 1
        if boundary is None:
            offset = torch.tensor(
                (T - 1) / 2,
                dtype=px.dtype,
                device=px.device,
            ).expand(B, 1, 1)
        else:
            offset = (boundary[:, 3] - 1) / 2
        penalty = offset.reshape(B, 1, 1) - torch.arange(
            T0, device=px.device
        ).reshape(1, 1, T0)
        penalty = penalty * delay_penalty
        px += penalty.to(px.dtype)
    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
        return -negated_loss
...
...
@@ -480,6 +520,7 @@ def _adjust_pruning_lower_bound(
    )
    return s_begin
# To get more insight of how we calculate pruning bounds, please read
# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
# (https://arxiv.org/pdf/2206.13236.pdf)
...
...
@@ -818,6 +859,7 @@ def rnnt_loss_pruned(
    termination_symbol: int,
    boundary: Tensor = None,
    modified: bool = False,
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses a pruned 'joiner' network output
...
...
@@ -842,6 +884,10 @@ def rnnt_loss_pruned(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...
...
@@ -861,6 +907,20 @@ def rnnt_loss_pruned(
        boundary=boundary,
        modified=modified,
    )
    if delay_penalty > 0.0:
        B, S, T0 = px.shape
        T = T0 if modified else T0 - 1
        if boundary is None:
            offset = torch.tensor(
                (T - 1) / 2,
                dtype=px.dtype,
                device=px.device,
            ).expand(B, 1, 1)
        else:
            offset = (boundary[:, 3] - 1) / 2
        penalty = offset.reshape(B, 1, 1) - torch.arange(
            T0, device=px.device
        ).reshape(1, 1, T0)
        penalty = penalty * delay_penalty
        px += penalty.to(px.dtype)
    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
        return -negated_loss
...
...
@@ -1105,6 +1165,7 @@ def rnnt_loss_smoothed(
    am_only_scale: float = 0.1,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
...
...
@@ -1138,6 +1199,10 @@ def rnnt_loss_smoothed(
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...
...
@@ -1170,6 +1235,20 @@ def rnnt_loss_smoothed(
        boundary=boundary,
        modified=modified,
    )
    if delay_penalty > 0.0:
        B, S, T0 = px.shape
        T = T0 if modified else T0 - 1
        if boundary is None:
            offset = torch.tensor(
                (T - 1) / 2,
                dtype=px.dtype,
                device=px.device,
            ).expand(B, 1, 1)
        else:
            offset = (boundary[:, 3] - 1) / 2
        penalty = offset.reshape(B, 1, 1) - torch.arange(
            T0, device=px.device
        ).reshape(1, 1, T0)
        penalty = penalty * delay_penalty
        px += penalty.to(px.dtype)
    scores_and_grads = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment