Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
8ed6deff
Commit
8ed6deff
authored
Jul 29, 2021
by
Daniel Povey
Browse files
More progress, nearly done but not compiled
parent
9f929ab3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
119 additions
and
68 deletions
+119
-68
torch_mutual_information/mutual_information_cuda.cpp
torch_mutual_information/mutual_information_cuda.cpp
+55
-7
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+64
-61
No files found.
torch_mutual_information/mutual_information_cuda.cpp
View file @
8ed6deff
#include <torch/extension.h>
// forward of mutual_information. """... """ comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
// It is the core recursion in the sequence-to-sequence mutual information
// computation.
// returns 'ans', of dimension B (batch size).
/*
Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. This It is the core recursion
in the sequence-to-sequence mutual information computation.
Args:
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
xy[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, which
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
default to (0, 0, S, T); and they should not exceed these bounds.
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directy to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
// [B][S][T+1]
torch
::
Tensor
py
,
// [B][S+1][T]
std
::
optional
<
torch
::
Tensor
>
boundary_info
,
// [B][4], int64_t.
torch
::
Tensor
p
);
// [B][S+1][T+1]; an output
// backward of mutual_information; returns (grad_px, grad_py)
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cuda
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
boundary_info
,
torch
::
Tensor
p
,
torch
::
Tensor
ans_grad
);
torch
::
Tensor
ans_grad
,
bool
overwrite_ans_grad
);
...
...
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
8ed6deff
...
...
@@ -8,7 +8,6 @@
// returns log(exp(x) + exp(y)).
__forceinline__
__device__
double
LogAdd
(
double
x
,
double
y
)
{
double
diff
;
if
(
x
<
y
)
{
diff
=
x
-
y
;
x
=
y
;
...
...
@@ -44,71 +43,59 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
/*
Forward of mutual_information. Each thread block computes blocks of the 'p'
array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
Thread
blocks loop over such blocks, but they might loop only once if there is
Thread
-
blocks loop over such blocks, but they might loop only once if there is
not that much data to process. We sequentially launch thread groups in
such a way that thread-blocks within a group do not depend on each other
(see the "iter" parameter).
(see the "iter" parameter). The blocks of the 'image' (i.e. of the p matrix)
that each group handles are arranged in a diagonal.
Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half.
scalar_t: the floating-point type, e.g. float, double; maybe eventually
half, although I think we don't support LogAdd for half yet.
BLOCK_SIZE: an integer power of two no greater than 32 (this limitation
is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
code).
Args:
px: log-odds ratio of generating next x in the sequence, i.e.
xy[b][s][t] is the log-odds probability of generating x_t of
the b'th image given subsequences of length (s, t). (See
mutual_information.py for more info). Shape [B][S][T + 1]
py: log-odds ratio of generating next y in the sequence.
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
xy[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[s][t] the mutual information between
sub-sequences of x and y of length s and t respectively.
Its shape is [B][S + 1][T + 1]. This function implements
the following recursion:
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(
if s > 0 or t > 0
)
if s > 0 or t > 0
,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, which
contains, for each batch element, [s_begin, t_begin, s_end, t_end]
which are the beginning and end (one-past-the-last) of the
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
default to (0, 0, S, T), and they should not exceed these bounds
or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
nput: input image, shape (B, C, T) where B is batch size, C is
the number of channels and T is the time axis. (For more-than-1d
convolution setups, T would really be more than 1 axis, reshaped).
params: of shape (C, N+1) where N is the number of linear regions in the
piecewise linear function; params[c][0] is l which is
a log scale parameter that dictates how far apart
the discontinuities in the piecewise linear function are,
and params[c][n+1] for 0 <= n < N are the derivatives
of the linear parts of the piecewise linear function.
The discontinuities of the function are at:
exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
output: The transformed input, shape (B , C, T)
images_per_thread_block: The number of images processed by each thread
block. The calling code must guarantee that this is a power
of 2, and that EITHER:
THREADS_PER_BLOCK / images_per_thread_block >= T
OR
images_per_thread_block == 1
.. this is used for a small optimization.
This kernel is allocated with `extern_buf` containing enough memory
to store 2*N + 3 values of type scalar_t.
default to (0, 0, S, T); and they should not exceed these bounds.
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directy to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
*/
template
<
typename
scalar_t
,
int
BLOCK_SIZE
>
// e.g. BLOCK_SIZE == 16 or 32. Note: we require the
// num-threads be at least 128.
int
BLOCK_SIZE
>
// e.g. BLOCK_SIZE == 16 or 32.
__global__
void
mutual_information_kernel
(
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
px
,
// B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
...
...
@@ -450,8 +437,8 @@ void mutual_information_kernel(
epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
the above becomes
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t] * exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1] * exp(p[b][s][t])
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]
)
* exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]
)
* exp(p[b][s][t])
Rearranging:
px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) (eq. 3a)
py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 3b)
...
...
@@ -485,15 +472,21 @@ void mutual_information_backward_kernel(
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
py
,
// B, S + 1, T.
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
p
,
// B, S + 1, T + 1. Produced in forward pass.
torch
::
PackedTensorAccessor32
<
scalar_t
,
1
>
ans_grad
,
// [B]. This is an input.
torch
::
PackedTensorAccessor32
<
scalar_t
,
1
>
ans_grad_compare
,
// [B]. A value will be written to here which
// should ideally equal ans_grad.
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
p_grad
,
// B, S + 1, T + 1. This is a temporary.
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
px_grad
,
// B, S, T + 1.
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
py_grad
,
// B, S + 1, T.
torch
::
PackedTensorAccessor32
<
int64_t
,
2
>
boundary
,
// B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
int
iter
)
{
// This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
int
iter
,
// This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
bool
overwrite_ans_grad
)
{
// If true, overwrite ans_grad with a value
// which, if everything is working correctly,
// should be identical or very close to the
// value of ans_grad that was passed in.
const
int
B
=
px
.
size
(
0
),
S
=
px
.
size
(
1
),
T
=
py
.
size
(
2
);
...
...
@@ -715,14 +708,13 @@ void mutual_information_backward_kernel(
}
if
(
threadIdx
.
x
==
0
&&
s_block_begin
==
s_begin
&&
t_block_
end
==
t_
en
d
)
t_block_
begin
==
t_
begin
&&
overwrite_ans_gra
d
)
ans_grad
[
b
]
=
p_buf
[
0
][
0
];
}
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
...
...
@@ -752,6 +744,9 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_blocks
=
128
,
BLOCK_SIZE
=
32
;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const
int
num_s_blocks
=
S
/
BLOCK_SIZE
+
1
,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
...
...
@@ -777,11 +772,15 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// backward of mutual_information; returns (grad_px, grad_py)
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked
// as it should.
torch
::
Tensor
mutual_information_backward_cuda
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
optional_boundary
,
torch
::
Tensor
p
,
torch
::
Tensor
ans_grad
)
{
torch
::
Tensor
ans_grad
,
bool
overwrite_ans_grad
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
...
...
@@ -813,6 +812,9 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_blocks
=
128
,
BLOCK_SIZE
=
32
;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const
int
num_s_blocks
=
S
/
BLOCK_SIZE
+
1
,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
...
...
@@ -833,7 +835,8 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
optional_boundary
.
value
().
packed_accessor32
<
int64_t
,
2
>
(),
iter
);
iter
,
overwrite_ans_grad
);
}
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment