Commit 8ed6deff authored by Daniel Povey

More progress, nearly done but not compiled

parent 9f929ab3
 #include <torch/extension.h>
-// forward of mutual_information. """... """ comment of `mutual_information`
-// in mutual_information.py documents the behavior of this function.
-// It is the core recursion in the sequence-to-sequence mutual information
-// computation.
-// returns 'ans', of dimension B (batch size).
+
+/*
+  Forward of mutual_information. See also the """... """ comment of
+  `mutual_information` in mutual_information.py. This is the core recursion
+  in the sequence-to-sequence mutual information computation.
+  Args:
+    px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+        generating the next x in the sequence, i.e.
+        px[b][s][t] is the log of
+          p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
+        i.e. the log-prob of generating x_s given subsequences of lengths
+        (s, t), divided by the prior probability of generating x_s. (See
+        mutual_information.py for more info).
+    py: The log-odds ratio of generating the next y in the sequence.
+        Shape [B][S + 1][T]
+    p: This function writes to p[b][s][t] the mutual information between
+        sub-sequences of x and y of length s and t respectively, from the
+        b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
+        Concretely, this function implements the following recursion,
+        in the case where s_begin == t_begin == 0:
+          p[b,0,0] = 0.0
+          p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                             p[b,s,t-1] + py[b,s,t-1])
+                     if s > 0 or t > 0,
+                     treating values with any -1 index as -infinity.
+        .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
+    boundary: If set, a tensor of shape [B][4] of type int64_t, which
+        contains, for each batch element b, a row boundary[b] equal to
+          [s_begin, t_begin, s_end, t_end]
+        which are the beginning and end (i.e. one-past-the-last) of the
+        x and y sequences that we should process. If not set, these
+        default to (0, 0, S, T); and they should not exceed these bounds.
+    ans: a tensor `ans` of shape [B], where this function will set
+          ans[b] = p[b][s_end][t_end],
+        with s_end and t_end being (S, T) if `boundary` was not specified,
+        and (boundary[b][2], boundary[b][3]) otherwise.
+        `ans` represents the mutual information between each pair of
+        sequences (i.e. x[b] and y[b], although the sequences are not
+        supplied directly to this function).
+  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
+  be at least 128.
+*/
 torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
                                       torch::Tensor py, // [B][S+1][T]
                                       std::optional<torch::Tensor> boundary_info, // [B][4], int64_t.
                                       torch::Tensor p); // [B][S+1][T+1]; an output
-// backward of mutual_information; returns (grad_px, grad_py)
+/*
+  backward of mutual_information; returns (grad_px, grad_py)
+  if overwrite_ans_grad == true, this function will overwrite ans_grad with a
+  value that, if the computation worked correctly, should be identical to or
+  very close to the value of ans_grad at entry. This can be used
+  to validate the correctness of this code.
+*/
 std::vector<torch::Tensor> mutual_information_backward_cuda(
     torch::Tensor px,
     torch::Tensor py,
     std::optional<torch::Tensor> boundary_info,
     torch::Tensor p,
-    torch::Tensor ans_grad);
+    torch::Tensor ans_grad,
+    bool overwrite_ans_grad);
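
For readers who want to check the recursion documented above without a GPU, the following is a minimal host-side C++ sketch for a single (x, y) pair with the default boundaries (0, 0, S, T). It is not part of this commit: the names `log_add` and `mutual_information_ref` and the nested-vector layout are assumptions made for illustration, and the batch dimension and `boundary` handling are left out.

```cpp
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Numerically stable log(exp(x) + exp(y)) that tolerates -infinity inputs.
// (Hypothetical helper; the commit's device-side LogAdd is not reproduced here.)
double log_add(double x, double y) {
  if (x < y) std::swap(x, y);             // ensure x >= y
  if (std::isinf(y) && y < 0) return x;   // exp(y) == 0; avoids (-inf) - (-inf)
  return x + std::log1p(std::exp(y - x));
}

// Reference recursion for one sequence pair.
//   px has shape [S][T + 1], py has shape [S + 1][T].
// Returns p of shape [S + 1][T + 1]; p[S][T] plays the role of `ans`.
Matrix mutual_information_ref(const Matrix& px, const Matrix& py, int S, int T) {
  const double neg_inf = -std::numeric_limits<double>::infinity();
  Matrix p(S + 1, std::vector<double>(T + 1, neg_inf));
  p[0][0] = 0.0;
  for (int s = 0; s <= S; ++s) {
    for (int t = 0; t <= T; ++t) {
      if (s == 0 && t == 0) continue;
      // Terms with a -1 index are treated as -infinity.
      const double from_x = (s > 0) ? p[s - 1][t] + px[s - 1][t] : neg_inf;
      const double from_y = (t > 0) ? p[s][t - 1] + py[s][t - 1] : neg_inf;
      p[s][t] = log_add(from_x, from_y);
    }
  }
  return p;
}
```

With px and py filled with zeros, every path through the lattice has weight 1, so p[s][t] is the log of the number of lattice paths from (0, 0) to (s, t); for S = T = 2 that gives p[2][2] = log(6), which is a convenient sanity check.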
......
@@ -8,7 +8,6 @@
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__ double LogAdd(double x, double y) {
   double diff;
   if (x < y) {
     diff = x - y;
     x = y;
@@ -44,71 +43,59 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
 /*
   Forward of mutual_information. Each thread block computes blocks of the 'p'
   array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
-  Thread blocks loop over such blocks, but they might loop only once if there is
+  Thread-blocks loop over such blocks, but they might loop only once if there is
   not that much data to process. We sequentially launch thread groups in
   such a way that thread-blocks within a group do not depend on each other
-  (see the "iter" parameter).
+  (see the "iter" parameter). The blocks of the 'image' (i.e. of the p matrix)
+  that each group handles are arranged in a diagonal.
   Template args:
-    scalar_t: the floating-point type, e.g. float, double, maybe half.
+    scalar_t: the floating-point type, e.g. float, double; maybe eventually
+        half, although I think we don't support LogAdd for half yet.
+    BLOCK_SIZE: an integer power of two no greater than 32 (this limitation
+        is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
+        code).
   Args:
-    px: log-odds ratio of generating next x in the sequence, i.e.
-        xy[b][s][t] is the log-odds probability of generating x_t of
-        the b'th image given subsequences of length (s, t). (See
-        mutual_information.py for more info). Shape [B][S][T + 1]
-    py: log-odds ratio of generating next y in the sequence.
+    px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
+        generating the next x in the sequence, i.e.
+        px[b][s][t] is the log of
+          p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
+        i.e. the log-prob of generating x_s given subsequences of lengths
+        (s, t), divided by the prior probability of generating x_s. (See
+        mutual_information.py for more info).
+    py: The log-odds ratio of generating the next y in the sequence.
         Shape [B][S + 1][T]
-    p: This function writes to p[s][t] the mutual information between
-        sub-sequences of x and y of length s and t respectively.
-        Its shape is [B][S + 1][T + 1]. This function implements
-        the following recursion:
+    p: This function writes to p[b][s][t] the mutual information between
+        sub-sequences of x and y of length s and t respectively, from the
+        b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
+        Concretely, this function implements the following recursion,
+        in the case where s_begin == t_begin == 0:
           p[b,0,0] = 0.0
           p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                              p[b,s,t-1] + py[b,s,t-1])
-                     (if s > 0 or t > 0)
+                     if s > 0 or t > 0,
+                     treating values with any -1 index as -infinity.
+        .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
     boundary: If set, a tensor of shape [B][4] of type int64_t, which
-        contains, for each batch element, [s_begin, t_begin, s_end, t_end]
-        which are the beginning and end (one-past-the-last) of the
+        contains, for each batch element b, a row boundary[b] equal to
+          [s_begin, t_begin, s_end, t_end]
+        which are the beginning and end (i.e. one-past-the-last) of the
         x and y sequences that we should process. If not set, these
-        default to (0, 0, S, T), and they should not exceed these bounds
-        or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
-    input: input image, shape (B, C, T) where B is batch size, C is
-        the number of channels and T is the time axis. (For more-than-1d
-        convolution setups, T would really be more than 1 axis, reshaped).
-    params: of shape (C, N+1) where N is the number of linear regions in the
-        piecewise linear function; params[c][0] is l which is
-        a log scale parameter that dictates how far apart
-        the discontinuities in the piecewise linear function are,
-        and params[c][n+1] for 0 <= n < N are the derivatives
-        of the linear parts of the piecewise linear function.
-        The discontinuities of the function are at:
-          exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
-    output: The transformed input, shape (B , C, T)
-    images_per_thread_block: The number of images processed by each thread
-        block. The calling code must guarantee that this is a power
-        of 2, and that EITHER:
-          THREADS_PER_BLOCK / images_per_thread_block >= T
-        OR
-          images_per_thread_block == 1
-        .. this is used for a small optimization.
-  This kernel is allocated with `extern_buf` containing enough memory
-  to store 2*N + 3 values of type scalar_t.
+        default to (0, 0, S, T); and they should not exceed these bounds.
+    ans: a tensor `ans` of shape [B], where this function will set
+          ans[b] = p[b][s_end][t_end],
+        with s_end and t_end being (S, T) if `boundary` was not specified,
+        and (boundary[b][2], boundary[b][3]) otherwise.
+        `ans` represents the mutual information between each pair of
+        sequences (i.e. x[b] and y[b], although the sequences are not
+        supplied directly to this function).
   The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
   be at least 128.
 */
 template <typename scalar_t,
-          int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32. Note: we require the
-                          // num-threads be at least 128.
+          int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32.
 __global__
 void mutual_information_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
@@ -450,8 +437,8 @@ void mutual_information_kernel(
 epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
 the above becomes
-px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t] * exp(p[b][s][t])
-py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1] * exp(p[b][s][t])
+px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
+py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
 Rearranging:
 px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) (eq. 3a)
 py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 3b)
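
Equations 3a and 3b can be exercised on the host with a short C++ sketch for a single sequence pair and default boundaries. This is an illustration under stated assumptions, not the kernel from this commit: `backward_ref`, `Grads` and the nested-vector layout are hypothetical names. The sketch also uses the relation p_grad[s][t] = px_grad[s][t] + py_grad[s][t], which follows from the forward recursion because p[s][t] and px[s][t] enter p[s+1][t] only through their sum (and likewise for py). The value accumulated at p_grad[0][0] should reproduce the incoming ans_grad, which is the kind of self-check the `overwrite_ans_grad` option describes.

```cpp
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical result bundle for the sketch below.
struct Grads {
  Matrix px_grad;         // shape [S][T + 1]
  Matrix py_grad;         // shape [S + 1][T]
  double ans_grad_check;  // p_grad[0][0]; should match the incoming ans_grad
};

// px: [S][T + 1], py: [S + 1][T], p: [S + 1][T + 1] from the forward pass.
Grads backward_ref(const Matrix& px, const Matrix& py, const Matrix& p,
                   int S, int T, double ans_grad) {
  Matrix p_grad(S + 1, std::vector<double>(T + 1, 0.0));
  Grads g{Matrix(S, std::vector<double>(T + 1, 0.0)),
          Matrix(S + 1, std::vector<double>(T, 0.0)), 0.0};
  p_grad[S][T] = ans_grad;
  // Visit cells in reverse order so p_grad[s+1][t] and p_grad[s][t+1]
  // are final before they are read.
  for (int s = S; s >= 0; --s) {
    for (int t = T; t >= 0; --t) {
      if (s < S) {  // eq. 3a
        g.px_grad[s][t] =
            p_grad[s + 1][t] * std::exp(p[s][t] + px[s][t] - p[s + 1][t]);
        p_grad[s][t] += g.px_grad[s][t];
      }
      if (t < T) {  // eq. 3b
        g.py_grad[s][t] =
            p_grad[s][t + 1] * std::exp(p[s][t] + py[s][t] - p[s][t + 1]);
        p_grad[s][t] += g.py_grad[s][t];
      }
    }
  }
  g.ans_grad_check = p_grad[0][0];
  return g;
}
```

For S = T = 1 with px and py all zero, the forward values are p = [[0, 0], [0, log 2]], and every entry of px_grad and py_grad comes out as 0.5 times ans_grad, since each of the two lattice paths carries half of the total weight.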
@@ -485,15 +472,21 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
     torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. Produced in forward pass.
     torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B]. This is an input.
+    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare, // [B]. A value will be written to here which
+                                                                 // should ideally equal ans_grad.
     torch::PackedTensorAccessor32<scalar_t, 3> p_grad, // B, S + 1, T + 1. This is a temporary.
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
     torch::PackedTensorAccessor32<int64_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
-    int iter) { // This kernel is sequentially called with 'iter' = num_iters
+    int iter, // This kernel is sequentially called with 'iter' = num_iters
               // - 1, num_iters - 2, .. 0, where num_iters can be taken to
               // be any sufficiently large number but will actually be:
               // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
               // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
+    bool overwrite_ans_grad) { // If true, overwrite ans_grad with a value
+                               // which, if everything is working correctly,
+                               // should be identical or very close to the
+                               // value of ans_grad that was passed in.
   const int B = px.size(0),
       S = px.size(1),
       T = py.size(2);
@@ -715,14 +708,13 @@ void mutual_information_backward_kernel(
     }
     if (threadIdx.x == 0 && s_block_begin == s_begin &&
-        t_block_end == t_end)
+        t_block_begin == t_begin && overwrite_ans_grad)
       ans_grad[b] = p_buf[0][0];
   }
 }
 // forward of mutual_information. See """... """ comment of `mutual_information` in
 // mutual_information.py for documentation of the behavior of this function.
 torch::Tensor mutual_information_cuda(torch::Tensor px,
@@ -752,6 +744,9 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
       num_blocks = 128,
       BLOCK_SIZE = 32;
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_iters = num_s_blocks + num_t_blocks - 1;
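
The block-count arithmetic and the diagonal grouping mentioned in the kernel comment can be illustrated with a small host-side C++ sketch. The helper names below are hypothetical, and the mapping "a block (s_block, t_block) belongs to the group with iter == s_block + t_block" is an assumption: it is one natural reading of "arranged in a diagonal", but the kernel body that assigns blocks to iterations is not shown in this diff.

```cpp
#include <utility>
#include <vector>

// Number of BLOCK_SIZE-wide blocks needed to cover n + 1 entries (indices 0..n).
// Rounding up, (n + 1 + block_size - 1) / block_size simplifies to n / block_size + 1.
int num_blocks_covering(int n, int block_size) {
  return n / block_size + 1;
}

// Blocks assumed to be processed together on a given iteration: those on the
// anti-diagonal s_block + t_block == iter. Such blocks do not depend on each
// other, because each block only needs its neighbours at (s_block - 1, t_block)
// and (s_block, t_block - 1), which lie on the previous diagonal.
std::vector<std::pair<int, int>> blocks_on_diagonal(int iter,
                                                    int num_s_blocks,
                                                    int num_t_blocks) {
  std::vector<std::pair<int, int>> blocks;
  for (int s_block = 0; s_block < num_s_blocks; ++s_block) {
    const int t_block = iter - s_block;
    if (t_block >= 0 && t_block < num_t_blocks)
      blocks.emplace_back(s_block, t_block);
  }
  return blocks;
}
```

For example, with S = 40, T = 70 and BLOCK_SIZE = 32 this gives num_s_blocks = 2, num_t_blocks = 3 and num_iters = 4, and the four diagonals together visit each of the 6 blocks exactly once.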
@@ -777,11 +772,15 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
 // backward of mutual_information; returns (grad_px, grad_py)
+// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
+// should be identical to the original ans_grad if the computation worked
+// as it should.
 torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
                                                torch::Tensor py,
                                                std::optional<torch::Tensor> optional_boundary,
                                                torch::Tensor p,
-                                               torch::Tensor ans_grad) {
+                                               torch::Tensor ans_grad,
+                                               bool overwrite_ans_grad) {
   TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
   TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
   TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
@@ -813,6 +812,9 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
       num_blocks = 128,
       BLOCK_SIZE = 32;
+  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
+  // so dividing by BLOCK_SIZE rounding up we get e.g.
+  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_iters = num_s_blocks + num_t_blocks - 1;
@@ -833,7 +835,8 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
         px_grad.packed_accessor32<scalar_t, 3>(),
         py_grad.packed_accessor32<scalar_t, 3>(),
         optional_boundary.value().packed_accessor32<int64_t, 2>(),
-        iter);
+        iter,
+        overwrite_ans_grad);
   }
   return std::vector<torch::Tensor>({px_grad, py_grad});
 }