OpenDAS / FAST-RNNT / Commits / 62506226
Commit 62506226
authored Jun 05, 2022 by Guo Liyong

fix typo in comments

parent a283cddf
Showing 5 changed files with 14 additions and 14 deletions
fast_rnnt/csrc/mutual_information.h (+3 -3)
fast_rnnt/csrc/mutual_information_cpu.cu (+1 -1)
fast_rnnt/csrc/mutual_information_cuda.cu (+6 -6)
fast_rnnt/python/fast_rnnt/mutual_information.py (+2 -2)
fast_rnnt/python/fast_rnnt/rnnt_loss.py (+2 -2)
fast_rnnt/csrc/mutual_information.h

@@ -73,7 +73,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
 modified. `modified` can be worked out from this. In not-modified case,
 it can be thought of as the log-odds ratio of generating the next x in
 the sequence, i.e.
-xy[b][s][t] is the log of
+px[b][s][t] is the log of
 p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
 i.e. the log-prob of generating x_s given subsequences of
 lengths (s, t), divided by the prior probability of generating x_s.

@@ -94,7 +94,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
 p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                    p[b,s,t-1] + py[b,s,t-1])
 ... treating values with any -1 index as -infinity.
-.. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
+.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
 @param boundary If set, a tensor of shape [B][4] of type int64_t, which
      contains, where for each batch element b, boundary[b]
      equals [s_begin, t_begin, s_end, t_end]

@@ -108,7 +108,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
 and (boundary[b][2], boundary[b][3]) otherwise.
 `ans` represents the mutual information between each pair of
 sequences (i.e. x[b] and y[b], although the sequences are not
-supplied directy to this function).
+supplied directly to this function).
 The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
 be at least 128.
...
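The recursion documented in the hunk above can be sanity-checked with a naive reference implementation. The sketch below is hypothetical (single batch element, no `boundary`, and the `modified` px layout of shape [S][T]); it is not code from the repository, but it follows the recursion and the LogAdd helper named in the header:

```python
import math

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b)), mirroring the header's LogAdd.
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))


def mutual_information_naive(px, py):
    """Naive O(S*T) reference for the documented recursion:

        p[s][t] = log_add(p[s-1][t-1] + px[s-1][t-1],
                          p[s][t-1]   + py[s][t-1])

    with any -1 index treated as -infinity and p[0][0] = 0.
    px: [S][T] (modified layout), py: [S+1][T]; returns p[S][T].
    """
    S, T = len(px), len(py[0])
    p = [[NEG_INF] * (T + 1) for _ in range(S + 1)]
    p[0][0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term1 = (p[s - 1][t - 1] + px[s - 1][t - 1]
                     if s > 0 and t > 0 else NEG_INF)
            term2 = p[s][t - 1] + py[s][t - 1] if t > 0 else NEG_INF
            p[s][t] = log_add(term1, term2)
    return p[S][T]
```

For example, with S = T = 1, px = [[0.0]] and py = [[-1.0], [-2.0]], the only surviving path is the px step from p[0][0], so the result is 0.0.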
fast_rnnt/csrc/mutual_information_cpu.cu

@@ -25,7 +25,7 @@ namespace fast_rnnt {
 // forward of mutual_information. See """... """ comment of
 // `mutual_information_recursion` in
-// in k2/python/k2/mutual_information.py for documentation of the
+// in python/fast_rnnt/mutual_information.py for documentation of the
 // behavior of this function.
 // px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
...
fast_rnnt/csrc/mutual_information_cuda.cu

@@ -43,7 +43,7 @@ namespace fast_rnnt {
 px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified;
     may be interpreted as the log-odds ratio of
     generating the next x in the sequence, i.e.
-    xy[b][s][t] is the log of
+    px[b][s][t] is the log of
     p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
     i.e. the log-prob of generating x_s given subsequences of lengths
     (s, t), divided by the prior probability of generating x_s. (See

@@ -65,7 +65,7 @@ namespace fast_rnnt {
                    p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
 treating values with any -1 index as -infinity.
-.. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
+.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
 boundary: If set, a tensor of shape [B][4] of type int64_t, which
     contains, where for each batch element b, boundary[b] equals
     [s_begin, t_begin, s_end, t_end]

@@ -79,7 +79,7 @@ namespace fast_rnnt {
 and (boundary[b][2], boundary[b][3]) otherwise.
 `ans` represents the mutual information between each pair of
 sequences (i.e. x[b] and y[b], although the sequences are not
-supplied directy to this function).
+supplied directly to this function).
 The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
 be at least 128.

@@ -274,7 +274,7 @@ __global__ void mutual_information_kernel(
 // and (2, 1); and so on. Note: not many threads participate in this
 // part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
 // out a very meaningful way for more threads to do work, that looked like
-// it would really spead things up.
+// it would really speed things up.
 // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
 // but we do at least do the I/O in an efficient way and keep the
 // inner loop simple and fast (e.g. no exp() or log()).

@@ -418,7 +418,7 @@ __global__ void mutual_information_backward_kernel(
 // be any sufficiently large number but will actually be:
 // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
 // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
-bool overwrite_ans_grad) { // If overwite_ans_grad == true, this function
+bool overwrite_ans_grad) { // If overwrite_ans_grad == true, this function
 // will overwrite ans_grad with a value which,
 // if everything is working correctly, should be
 // identical or very close to the value of

@@ -554,7 +554,7 @@ __global__ void mutual_information_backward_kernel(
 // We can apply this formula to the entire block even if we are processing
 // a partial block; we have ensured that x_buf and y_buf contain
 // -infinity, and p contains 0, for out-of-range elements, so we'll get
-// x_buf and y_buf containing 0 after applying the followin formulas.
+// x_buf and y_buf containing 0 after applying the following formulas.
 int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
 // Mathematically the following is doing:
 // term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset]) (4a)
...
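The backward-kernel comment above says the grid size is num_s_blocks + num_t_blocks - 1, i.e. one pass per anti-diagonal of tiles, since tiles on the same anti-diagonal have no data dependence on each other. A small illustrative helper (hypothetical, not part of the source) makes that wavefront schedule concrete:

```python
def block_diagonals(S, T, block_size=128):
    """Enumerate tiles of an (S+1) x (T+1) grid in anti-diagonal
    (wavefront) order. Tiles on the same diagonal are mutually
    independent, which is what lets a kernel process them in parallel.
    Uses num_s_blocks = S // block_size + 1 and
    num_t_blocks = T // block_size + 1, as in the comment above."""
    num_s_blocks = S // block_size + 1
    num_t_blocks = T // block_size + 1
    diags = []
    for d in range(num_s_blocks + num_t_blocks - 1):
        diags.append([(sb, d - sb)
                      for sb in range(num_s_blocks)
                      if 0 <= d - sb < num_t_blocks])
    return diags

# e.g. S = 300, T = 200, BLOCK_SIZE = 128 gives 3 s-blocks and
# 2 t-blocks, hence 3 + 2 - 1 = 4 anti-diagonals.
```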
fast_rnnt/python/fast_rnnt/mutual_information.py

@@ -44,7 +44,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
 A torch.Tensor of some floating point type, with shape
 ``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
 length of the ``x`` sequence (including representations of
-``EOS`` symbols but not ``BOS`` symbols), and ``S`` is the
+``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
 length of the ``y`` sequence (including representations of
 ``EOS`` symbols but not ``BOS`` symbols). In the mutual
 information application, ``px[b][s][t]`` would represent the

@@ -199,7 +199,7 @@ def mutual_information_recursion(
 A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
 where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
 (including representations of ``EOS`` symbols but not ``BOS`` symbols),
-and ``S`` is the length of the ``y`` sequence (including representations
+and ``T`` is the length of the ``y`` sequence (including representations
 of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
 application, ``px[b][s][t]`` would represent the following log odds
 ratio; ignoring the b index on the right to make the notation more
...
fast_rnnt/python/fast_rnnt/rnnt_loss.py

@@ -1042,8 +1042,8 @@ def get_rnnt_logprobs_smoothed(
 py = py_am + py_lm - normalizers
 py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
-py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
+py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][1][T]
-py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]
+py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][1]
 combined_scale = 1.0 - lm_only_scale - am_only_scale
...
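The corrected shape comments in the hunk above follow standard NumPy/PyTorch broadcasting: an acoustic-only term of shape [B][1][T] and an LM-only term of shape [B][S+1][1] can be combined elementwise into a full [B][S+1][T] tensor. A toy helper (illustrative only, not fast_rnnt code) that computes the broadcast result shape:

```python
def broadcast_shape(a, b):
    """Broadcast result shape of two equal-rank tensor shapes, following
    NumPy/PyTorch rules: dims are compared pairwise, and a size-1 dim
    stretches to match the other; any other mismatch is an error."""
    out = []
    for x, y in zip(reversed(a), reversed(b)):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible dims {x} and {y}")
        out.append(max(x, y))
    return tuple(reversed(out))

B, S, T = 2, 3, 4
# py_amonly-like term [B][1][T] vs py_lmonly-like term [B][S+1][1]:
print(broadcast_shape((B, 1, T), (B, S + 1, 1)))  # -> (2, 4, 4)
```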