OpenDAS / FAST-RNNT

Commit 8ed6deff, authored Jul 29, 2021 by Daniel Povey

    More progress, nearly done but not compiled

Parent: 9f929ab3
Showing 2 changed files with 119 additions and 68 deletions:

    torch_mutual_information/mutual_information_cuda.cpp         +55   -7
    torch_mutual_information/mutual_information_cuda_kernel.cu   +64  -61
torch_mutual_information/mutual_information_cuda.cpp
 #include <torch/extension.h>

-// forward of mutual_information. """... """ comment of `mutual_information`
-// in mutual_information.py documents the behavior of this function.
-// It is the core recursion in the sequence-to-sequence mutual information
-// computation.
-// returns 'ans', of dimension B (batch size).
+/*
+  Forward of mutual_information.  See also the """...""" comment of
+  `mutual_information` in mutual_information.py.  This is the core recursion
+  in the sequence-to-sequence mutual information computation.
+
+  Args:
+    px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+        generating the next x in the sequence, i.e. px[b][s][t] is the log of
+            p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
+        i.e. the log-prob of generating x_s given subsequences of lengths
+        (s, t), divided by the prior probability of generating x_s.  (See
+        mutual_information.py for more info).
+    py: The log-odds ratio of generating the next y in the sequence.
+        Shape [B][S + 1][T].
+    p: This function writes to p[b][s][t] the mutual information between
+       sub-sequences of x and y of lengths s and t respectively, from the
+       b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
+       Concretely, this function implements the following recursion,
+       in the case where s_begin == t_begin == 0:
+
+           p[b,0,0] = 0.0
+           p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                              p[b,s,t-1] + py[b,s,t-1])   if s > 0 or t > 0,
+
+       treating values with any -1 index as -infinity.
+       .. if `boundary` is set, we start from p[b,s_begin,t_begin] = 0.0.
+    boundary: If set, a tensor of shape [B][4] of type int64_t, which
+        contains, for each batch element b, boundary[b] equal to
+            [s_begin, t_begin, s_end, t_end],
+        which are the beginning and end (i.e. one-past-the-last) of the
+        x and y sequences that we should process.  If not set, these
+        default to (0, 0, S, T); and they should not exceed these bounds.
+    ans: a tensor `ans` of shape [B], where this function will set
+           ans[b] = p[b][s_end][t_end],
+        with (s_end, t_end) being (boundary[b][2], boundary[b][3]) if
+        `boundary` was specified, and (S, T) otherwise.
+        `ans` represents the mutual information between each pair of
+        sequences (i.e. x[b] and y[b], although the sequences are not
+        supplied directly to this function).
+
+  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
+  be at least 128.
+*/
 torch::Tensor mutual_information_cuda(
     torch::Tensor px,                             // [B][S][T+1]
     torch::Tensor py,                             // [B][S+1][T]
     std::optional<torch::Tensor> boundary_info,   // [B][4], int64_t.
     torch::Tensor p);                             // [B][S+1][T+1]; an output
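For reference, the recursion documented above can be written as a few lines of host code. This is a sketch for one batch element with `boundary` unset; the helper names (`log_add`, `mutual_information_ref`) are illustrative only and not part of this commit, which implements the computation in the CUDA kernel in the second file.

#include <cmath>
#include <limits>
#include <utility>
#include <vector>

// log(exp(x) + exp(y)), computed stably; mirrors the LogAdd device function.
static double log_add(double x, double y) {
  if (x < y) std::swap(x, y);              // now x >= y (or y is -infinity)
  return x + std::log1p(std::exp(y - x));  // exp(y - x) <= 1, so no overflow
}

// px is [S][T+1], py is [S+1][T], p (output) is [S+1][T+1].
// Returns ans = p[S][T], i.e. the mutual information for this sequence pair.
double mutual_information_ref(const std::vector<std::vector<double>>& px,
                              const std::vector<std::vector<double>>& py,
                              std::vector<std::vector<double>>& p) {
  const double ninf = -std::numeric_limits<double>::infinity();
  int S = (int)px.size(), T = (int)py[0].size();
  for (int s = 0; s <= S; s++)
    for (int t = 0; t <= T; t++)
      p[s][t] = (s == 0 && t == 0)
          ? 0.0
          : log_add(s > 0 ? p[s - 1][t] + px[s - 1][t] : ninf,   // came from x
                    t > 0 ? p[s][t - 1] + py[s][t - 1] : ninf);  // came from y
  return p[S][T];
}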
-// backward of mutual_information; returns (grad_px, grad_py)
+/*
+  backward of mutual_information; returns (grad_px, grad_py).
+
+  if overwrite_ans_grad == true, this function will overwrite ans_grad with a
+  value that, if the computation worked correctly, should be identical to or
+  very close to the value of ans_grad at entry.  This can be used
+  to validate the correctness of this code.
+*/
 std::vector<torch::Tensor> mutual_information_backward_cuda(
     torch::Tensor px,
     torch::Tensor py,
     std::optional<torch::Tensor> boundary_info,
     torch::Tensor p,
-    torch::Tensor ans_grad);
+    torch::Tensor ans_grad,
+    bool overwrite_ans_grad);
...

torch_mutual_information/mutual_information_cuda_kernel.cu
...

@@ -8,7 +8,6 @@
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__ double LogAdd(double x, double y) {
   double diff;
   if (x < y) {
     diff = x - y;
     x = y;
...
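The hunk is cut off above; for context, a stable log-add is conventionally completed as below. This is a host-side sketch: the tail computed from `diff` is assumed, since the rest of the device function is not visible in this diff.

#include <cmath>

// Host-side sketch of LogAdd; the part after `x = y` is an assumption,
// not this commit's code.
double LogAddRef(double x, double y) {
  double diff;
  if (x < y) {
    diff = x - y;  // diff <= 0
    x = y;         // x is now max(x, y)
  } else {
    diff = y - x;
  }
  // x + log(1 + exp(diff)); exp(diff) <= 1, so this cannot overflow.
  // (A production version would also guard the case x == y == -infinity.)
  return x + std::log1p(std::exp(diff));
}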
@@ -44,71 +43,59 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
 /*
   Forward of mutual_information.  Each thread block computes blocks of the 'p'
   array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
-  Thread blocks loop over such blocks, but they might loop only once if there is
+  Thread-blocks loop over such blocks, but they might loop only once if there is
   not that much data to process.  We sequentially launch thread groups in
   such a way that thread-blocks within a group do not depend on each other
-  (see the "iter" parameter).
+  (see the "iter" parameter).  The blocks of the 'image' (i.e. of the p matrix)
+  that each group handles are arranged in a diagonal.

   Template args:
-      scalar_t: the floating-point type, e.g. float, double, maybe half.
+      scalar_t: the floating-point type, e.g. float, double; maybe eventually
+          half, although I think we don't support LogAdd for half yet.
+      BLOCK_SIZE: an integer power of two no greater than 32 (this limitation
+          is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
+          code).

   Args:
-      px: log-odds ratio of generating next x in the sequence, i.e.
-          px[b][s][t] is the log-odds probability of generating x_t of
-          the b'th image given subsequences of length (s, t).  (See
-          mutual_information.py for more info).  Shape [B][S][T + 1]
-      py: log-odds ratio of generating next y in the sequence.
+      px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+          generating the next x in the sequence, i.e. px[b][s][t] is the log of
+              p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
+          i.e. the log-prob of generating x_s given subsequences of lengths
+          (s, t), divided by the prior probability of generating x_s.  (See
+          mutual_information.py for more info).
+      py: The log-odds ratio of generating the next y in the sequence.
           Shape [B][S + 1][T]
-      p: This function writes to p[s][t] the mutual information between
-         sub-sequences of x and y of length s and t respectively.
-         Its shape is [B][S + 1][T + 1].  This function implements
-         the following recursion:
+      p: This function writes to p[b][s][t] the mutual information between
+         sub-sequences of x and y of lengths s and t respectively, from the
+         b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
+         Concretely, this function implements the following recursion,
+         in the case where s_begin == t_begin == 0:

             p[b,0,0] = 0.0
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
-                   (if s > 0 or t > 0)
+                   if s > 0 or t > 0,
          treating values with any -1 index as -infinity.
+         .. if `boundary` is set, we start from p[b,s_begin,t_begin] = 0.0.
       boundary: If set, a tensor of shape [B][4] of type int64_t, which
-          contains, for each batch element, [s_begin, t_begin, s_end, t_end]
-          which are the beginning and end (one-past-the-last) of the
+          contains, for each batch element b, boundary[b] equal to
+              [s_begin, t_begin, s_end, t_end],
+          which are the beginning and end (i.e. one-past-the-last) of the
           x and y sequences that we should process.  If not set, these
-          default to (0, 0, S, T), and they should not exceed these bounds
-          or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
-      input: input image, shape (B, C, T) where B is batch size, C is
-          the number of channels and T is the time axis.  (For more-than-1d
-          convolution setups, T would really be more than 1 axis, reshaped).
-      params: of shape (C, N+1) where N is the number of linear regions in the
-          piecewise linear function; params[c][0] is l which is
-          a log scale parameter that dictates how far apart
-          the discontinuities in the piecewise linear function are,
-          and params[c][n+1] for 0 <= n < N are the derivatives
-          of the linear parts of the piecewise linear function.
-          The discontinuities of the function are at:
-               exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
-      output: The transformed input, shape (B, C, T)
-      images_per_thread_block: The number of images processed by each thread
-          block.  The calling code must guarantee that this is a power
-          of 2, and that EITHER:
-               THREADS_PER_BLOCK / images_per_thread_block >= T
-          OR
-               images_per_thread_block == 1
-          .. this is used for a small optimization.
-
-      This kernel is allocated with `extern_buf` containing enough memory
-      to store 2*N + 3 values of type scalar_t.
+          default to (0, 0, S, T); and they should not exceed these bounds.
+      ans: a tensor `ans` of shape [B], where this function will set
+             ans[b] = p[b][s_end][t_end],
+          with (s_end, t_end) being (boundary[b][2], boundary[b][3]) if
+          `boundary` was specified, and (S, T) otherwise.
+          `ans` represents the mutual information between each pair of
+          sequences (i.e. x[b] and y[b], although the sequences are not
+          supplied directly to this function).

   The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
   be at least 128.
 */
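A small standalone sketch of that scheduling follows. The block-to-iteration mapping i + j == iter is an assumption for illustration; the comment only states that each group's blocks lie on a diagonal and are mutually independent.

#include <cstdio>

// Enumerate the (assumed) anti-diagonal schedule: on iteration `iter`, every
// block (i, j) of the p matrix with i + j == iter can run in parallel, since
// each block depends only on blocks (i-1, j) and (i, j-1), which belong to
// earlier iterations.
int main() {
  const int num_s_blocks = 3, num_t_blocks = 4;      // example sizes
  int num_iters = num_s_blocks + num_t_blocks - 1;   // as in the code below
  for (int iter = 0; iter < num_iters; iter++) {
    std::printf("iter %d:", iter);
    for (int i = 0; i < num_s_blocks; i++) {
      int j = iter - i;
      if (j >= 0 && j < num_t_blocks)
        std::printf(" (%d,%d)", i, j);               // blocks launched together
    }
    std::printf("\n");
  }
  return 0;
}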
 template <typename scalar_t,
-          int BLOCK_SIZE>   // e.g. BLOCK_SIZE == 16 or 32.  Note: we require the
-                            // num-threads be at least 128.
+          int BLOCK_SIZE>   // e.g. BLOCK_SIZE == 16 or 32.
 __global__ void mutual_information_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> px,  // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
...

@@ -450,8 +437,8 @@ void mutual_information_kernel(
 epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
 the above becomes
-   px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t] * exp(p[b][s][t])
-   py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1] * exp(p[b][s][t])
+   px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
+   py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
 Rearranging:
   px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])   (eq. 3a)
  py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])   (eq. 3b)
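Eqs. (3a) and (3b) can be exercised with a short host-side sketch for one batch element with boundary = (0, 0, S, T). The accumulation p_grad[s][t] = px_grad[s][t] + py_grad[s][t] is inferred from the forward recursion rather than visible in this hunk, and the function name is illustrative.

#include <cmath>
#include <vector>

// Host-side sketch of the backward recursion using eqs. (3a)/(3b).
void mutual_information_backward_ref(
    const std::vector<std::vector<double>>& px,   // [S][T+1]
    const std::vector<std::vector<double>>& py,   // [S+1][T]
    const std::vector<std::vector<double>>& p,    // [S+1][T+1], from forward pass
    double ans_grad,                              // gradient w.r.t. ans = p[S][T]
    std::vector<std::vector<double>>& px_grad,    // [S][T+1], output
    std::vector<std::vector<double>>& py_grad) {  // [S+1][T], output
  int S = (int)px.size(), T = (int)py[0].size();
  std::vector<std::vector<double>> p_grad(S + 1, std::vector<double>(T + 1, 0.0));
  p_grad[S][T] = ans_grad;
  // Traverse in the reverse of the forward order, so p_grad[s+1][t] and
  // p_grad[s][t+1] are fully accumulated before p_grad[s][t] is used.
  for (int s = S; s >= 0; s--) {
    for (int t = T; t >= 0; t--) {
      if (s < S) {  // eq. (3a)
        px_grad[s][t] = p_grad[s + 1][t] * std::exp(p[s][t] + px[s][t] - p[s + 1][t]);
        p_grad[s][t] += px_grad[s][t];
      }
      if (t < T) {  // eq. (3b)
        py_grad[s][t] = p_grad[s][t + 1] * std::exp(p[s][t] + py[s][t] - p[s][t + 1]);
        p_grad[s][t] += py_grad[s][t];
      }
    }
  }
}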
...

@@ -485,15 +472,21 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> py,        // B, S + 1, T.
     torch::PackedTensorAccessor32<scalar_t, 3> p,         // B, S + 1, T + 1.  Produced in forward pass.
     torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,  // [B].  This is an input.
-    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare,  // [B].  A value will be written to here which
-                                                                  // should ideally equal ans_grad.
     torch::PackedTensorAccessor32<scalar_t, 3> p_grad,    // B, S + 1, T + 1.  This is a temporary.
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad,   // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad,   // B, S + 1, T.
     torch::PackedTensorAccessor32<int64_t, 2> boundary,   // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
-    int iter) {                  // This kernel is sequentially called with 'iter' = num_iters
+    int iter,                    // This kernel is sequentially called with 'iter' = num_iters
                                  // - 1, num_iters - 2, .. 0, where num_iters can be taken to
                                  // be any sufficiently large number but will actually be:
                                  // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
                                  // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
+    bool overwrite_ans_grad) {   // If true, overwrite ans_grad with a value
+                                 // which, if everything is working correctly,
+                                 // should be identical or very close to the
+                                 // value of ans_grad that was passed in.
   const int B = px.size(0),
             S = px.size(1),
             T = py.size(2);
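The comment above describes a countdown over `iter`; a minimal CUDA sketch of that launch pattern, with a hypothetical stub kernel rather than this file's kernel, looks like the following. Each launch handles one diagonal of blocks, and synchronizing between launches guarantees the data dependencies between diagonals.

#include <cstdio>

// Hypothetical stand-in for mutual_information_backward_kernel, to show the
// driver loop only; it does none of the real work.
__global__ void stub_kernel(int iter) {
  if (blockIdx.x == 0 && threadIdx.x == 0)
    printf("processing diagonal %d\n", iter);
}

int main() {
  int num_s_blocks = 4, num_t_blocks = 7;            // example sizes
  int num_iters = num_s_blocks + num_t_blocks - 1;
  for (int iter = num_iters - 1; iter >= 0; --iter) {
    stub_kernel<<<128, 256>>>(iter);                 // num_blocks = 128 as below
    cudaDeviceSynchronize();  // finish this diagonal before starting the next
  }
  return 0;
}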
...

@@ -715,14 +708,13 @@ void mutual_information_backward_kernel(
   }
   if (threadIdx.x == 0 && s_block_begin == s_begin &&
-      t_block_end == t_end)
+      t_block_begin == t_begin && overwrite_ans_grad)
     ans_grad[b] = p_buf[0][0];
   }
 }
 // forward of mutual_information.  See """... """ comment of `mutual_information` in
 // mutual_information.py for documentation of the behavior of this function.
 torch::Tensor mutual_information_cuda(torch::Tensor px,
...

@@ -752,6 +744,9 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
       num_blocks = 128,
       BLOCK_SIZE = 32;
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
   const int num_s_blocks = S / BLOCK_SIZE + 1,
             num_t_blocks = T / BLOCK_SIZE + 1,
             num_iters = num_s_blocks + num_t_blocks - 1;
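The identity claimed in the new comment holds exactly under integer division, since S + 1 + BLOCK_SIZE - 1 = S + BLOCK_SIZE and BLOCK_SIZE / BLOCK_SIZE = 1. A throwaway check (mine, not part of the commit):

#include <cassert>

// Verify: ceil((S+1) / BLOCK_SIZE) == (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE
//                                  == (S + BLOCK_SIZE) / BLOCK_SIZE
//                                  == S / BLOCK_SIZE + 1   (integer division).
int main() {
  const int BLOCK_SIZE = 32;
  for (int S = 0; S < 10000; S++)
    assert((S + 1 + BLOCK_SIZE - 1) / BLOCK_SIZE == S / BLOCK_SIZE + 1);
  return 0;
}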
...

@@ -777,11 +772,15 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
 // backward of mutual_information; returns (grad_px, grad_py)
+// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
+// should be identical to the original ans_grad if the computation worked
+// as it should.
 std::vector<torch::Tensor> mutual_information_backward_cuda(
     torch::Tensor px,
     torch::Tensor py,
     std::optional<torch::Tensor> optional_boundary,
     torch::Tensor p,
-    torch::Tensor ans_grad) {
+    torch::Tensor ans_grad,
+    bool overwrite_ans_grad) {
   TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
   TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
   TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
...

@@ -813,6 +812,9 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
       num_blocks = 128,
       BLOCK_SIZE = 32;
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
   const int num_s_blocks = S / BLOCK_SIZE + 1,
             num_t_blocks = T / BLOCK_SIZE + 1,
             num_iters = num_s_blocks + num_t_blocks - 1;
...

@@ -833,7 +835,8 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
             px_grad.packed_accessor32<scalar_t, 3>(),
             py_grad.packed_accessor32<scalar_t, 3>(),
             optional_boundary.value().packed_accessor32<int64_t, 2>(),
-            iter);
+            iter, overwrite_ans_grad);
   }
   return std::vector<torch::Tensor>({px_grad, py_grad});
 }