OpenDAS / FAST-RNNT / Commits

Commit 62506226
authored Jun 05, 2022 by Guo Liyong

fix typo in comments

parent a283cddf
Showing 5 changed files with 14 additions and 14 deletions (+14 -14)
fast_rnnt/csrc/mutual_information.h                 +3 -3
fast_rnnt/csrc/mutual_information_cpu.cu            +1 -1
fast_rnnt/csrc/mutual_information_cuda.cu           +6 -6
fast_rnnt/python/fast_rnnt/mutual_information.py    +2 -2
fast_rnnt/python/fast_rnnt/rnnt_loss.py             +2 -2
fast_rnnt/csrc/mutual_information.h
...
@@ -73,7 +73,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
modified. `modified` can be worked out from this. In not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
- xy[b][s][t] is the log of
+ px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
...
@@ -94,7 +94,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                   p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
- .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
+ .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t, which
contains, where for each batch element b, boundary[b]
equals [s_begin, t_begin, s_end, t_end]
...
@@ -108,7 +108,7 @@ FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
- supplied directy to this function).
+ supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
...
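The recursion quoted in the hunks above is compact enough to express as a short reference implementation. The following is a minimal sketch rather than the library's optimized kernel: the function name `mutual_information_reference` is hypothetical, torch is assumed as the framework, the `boundary` option is ignored, and the indexing follows the formula exactly as quoted (which implies px of shape [B, S, T] and py of shape [B, S+1, T]).

```python
import torch


def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """Naive reference for the recursion quoted above:

        p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                           p[b,s,t-1]   + py[b,s,t-1]),

    treating values with any -1 index as -infinity and starting from
    p[b,0,0] = 0.0 (no `boundary` handling).  Returns p[:, S, T]."""
    B, S, T = px.shape                  # this indexing implies px: [B, S, T]
    assert py.shape == (B, S + 1, T)
    neg_inf = torch.full((B,), float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    p[:, 0, 0] = 0.0                    # start point when `boundary` is not given
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            # Terms whose index would be -1 are treated as -infinity.
            term_x = (p[:, s - 1, t - 1] + px[:, s - 1, t - 1]
                      if (s > 0 and t > 0) else neg_inf)
            term_y = (p[:, s, t - 1] + py[:, s, t - 1]
                      if t > 0 else neg_inf)
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]


# Toy usage: one mutual-information value per batch element.
px = torch.randn(2, 3, 4)
py = torch.randn(2, 4, 4)
print(mutual_information_reference(px, py).shape)  # torch.Size([2])
```

A plain dynamic-programming loop like this is O(B·S·T) and is only meant to make the comment's semantics concrete; the CUDA kernel touched later in this commit computes the same quantity blockwise.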
fast_rnnt/csrc/mutual_information_cpu.cu
...
@@ -25,7 +25,7 @@ namespace fast_rnnt {
// forward of mutual_information. See """... """ comment of
// `mutual_information_recursion` in
- // in k2/python/k2/mutual_information.py for documentation of the
+ // in python/fast_rnnt/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
...
fast_rnnt/csrc/mutual_information_cuda.cu
...
@@ -43,7 +43,7 @@ namespace fast_rnnt {
px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified;
may be interpreted as the log-odds ratio of
generating the next x in the sequence, i.e.
- xy[b][s][t] is the log of
+ px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
...
@@ -65,7 +65,7 @@ namespace fast_rnnt {
                   p[b,s,t-1] + py[b,s,t-1])   (eq. 0)
treating values with any -1 index as -infinity.
- .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
+ .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, which
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
...
@@ -79,7 +79,7 @@ namespace fast_rnnt {
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
- supplied directy to this function).
+ supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
...
@@ -274,7 +274,7 @@ __global__ void mutual_information_kernel(
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
- // it would really spead things up.
+ // it would really speed things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
...
@@ -418,7 +418,7 @@ __global__ void mutual_information_backward_kernel(
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
- bool overwrite_ans_grad) {  // If overwite_ans_grad == true, this function
+ bool overwrite_ans_grad) {  // If overwrite_ans_grad == true, this function
// will overwrite ans_grad with a value which,
// if everything is working correctly, should be
// identical or very close to the value of
...
@@ -554,7 +554,7 @@ __global__ void mutual_information_backward_kernel(
// We can apply this formula to the entire block even if we are processing
// a partial block; we have ensured that x_buf and y_buf contain
// -infinity, and p contains 0, for out-of-range elements, so we'll get
- // x_buf and y_buf containing 0 after applying the followin formulas.
+ // x_buf and y_buf containing 0 after applying the following formulas.
int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
// Mathematically the following is doing:
// term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset])   (4a)
...
fast_rnnt/python/fast_rnnt/mutual_information.py
...
@@ -44,7 +44,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
- ``EOS`` symbols but not ``BOS`` symbols), and ``S`` is the
+ ``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
...
@@ -199,7 +199,7 @@ def mutual_information_recursion(
A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
(including representations of ``EOS`` symbols but not ``BOS`` symbols),
- and ``S`` is the length of the ``y`` sequence (including representations
+ and ``T`` is the length of the ``y`` sequence (including representations
of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
application, ``px[b][s][t]`` would represent the following log odds
ratio; ignoring the b index on the right to make the notation more
...
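To make the corrected ``S``/``T`` wording concrete, here is a small shape sketch. The shape of ``px`` follows the docstring excerpt above; the shape of ``py`` and the commented-out call (with its `px`, `py`, `boundary` arguments, inferred from the hunk header) are assumptions, not a verified API.

```python
import torch

# import fast_rnnt  # module name assumed from the file path above

B, S, T = 2, 5, 7   # batch size, length of x (EOS included, BOS excluded), length of y

# px: [B][S][T+1], log-odds of producing the next x symbol (per the docstring).
# py: [B][S+1][T], the analogous quantity for the next y symbol (assumed shape).
px = torch.randn(B, S, T + 1)
py = torch.randn(B, S + 1, T)

# Hypothetical call; signature assumed from the hunk header above:
# mi = fast_rnnt.mutual_information_recursion(px, py, boundary=None)  # expected shape: [B]
```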
fast_rnnt/python/fast_rnnt/rnnt_loss.py
...
@@ -1042,8 +1042,8 @@ def get_rnnt_logprobs_smoothed(
    py = py_am + py_lm - normalizers
    py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
-   py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
+   py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][1][T]
-   py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]
+   py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][1]
    combined_scale = 1.0 - lm_only_scale - am_only_scale
...
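As a quick check on the corrected shape comments, the broadcasting can be reproduced with placeholder tensors. This is only a hedged sketch: the shapes of ``py_am``, ``py_lm`` and the normalizer tensors are inferred from the comments in the hunk (and the context line ``py = py_am + py_lm - normalizers``), not copied from the real ``get_rnnt_logprobs_smoothed``, and the values are random placeholders.

```python
import torch

# Toy sizes; shapes follow the comments in the hunk above.
B, S, T = 2, 3, 4

py_am = torch.randn(B, 1, T)                    # assumed shape [B][1][T]
py_lm = torch.randn(B, S + 1, 1)                # assumed shape [B][S+1][1]
normalizers = torch.randn(B, S + 1, T)          # assumed shape [B][S+1][T]
py_lm_unigram = torch.randn(())                 # scalar, as in the comment
amonly_normalizers = torch.randn(B, 1, T)       # assumed shape [B][1][T]
lmonly_normalizers = torch.randn(B, S + 1, 1)   # assumed shape [B][S+1][1]

py = py_am + py_lm - normalizers                         # broadcasts to [B][S+1][T]
py_amonly = py_am + py_lm_unigram - amonly_normalizers   # [B][1][T], per the corrected comment
py_lmonly = py_lm - lmonly_normalizers                   # [B][S+1][1], per the corrected comment

assert py.shape == (B, S + 1, T)
assert py_amonly.shape == (B, 1, T)
assert py_lmonly.shape == (B, S + 1, 1)
```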