Commit 8ed6deff authored by Daniel Povey's avatar Daniel Povey
Browse files

More progress, nearly done but not compiled

parent 9f929ab3
#include <torch/extension.h>
// forward of mutual_information. """... """ comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
// It is the core recursion in the sequence-to-sequence mutual information
// computation.
// returns 'ans', of dimension B (batch size).
/*
Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
Args:
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, which
contains, for each batch element b, a row boundary[b] equal to
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
default to (0, 0, S, T); and they should not exceed these bounds.
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
// Declaration of the forward entry point documented in the comment above.
// Returns `ans` (shape [B]); `p` is also written to as an output.
torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
std::optional<torch::Tensor> boundary_info, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
// backward of mutual_information; returns (grad_px, grad_py)
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor> mutual_information_backward_cuda(
torch::Tensor px,
torch::Tensor py,
std::optional<torch::Tensor> boundary_info,
torch::Tensor p,
torch::Tensor ans_grad);
torch::Tensor ans_grad,
bool overwrite_ans_grad);
......
......@@ -8,7 +8,6 @@
// returns log(exp(x) + exp(y)).
__forceinline__ __device__ double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
......@@ -44,71 +43,59 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
/*
Forward of mutual_information. Each thread block computes blocks of the 'p'
array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
Thread blocks loop over such blocks, but they might loop only once if there is
Thread-blocks loop over such blocks, but they might loop only once if there is
not that much data to process. We sequentially launch thread groups in
such a way that thread-blocks within a group do not depend on each other
(see the "iter" parameter).
(see the "iter" parameter). The blocks of the 'image' (i.e. of the p matrix)
that each group handles are arranged in a diagonal.
Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half.
scalar_t: the floating-point type, e.g. float, double; maybe eventually
half, although I think we don't support LogAdd for half yet.
BLOCK_SIZE: an integer power of two no greater than 32 (this limitation
is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
code).
Args:
px: log-odds ratio of generating next x in the sequence, i.e.
xy[b][s][t] is the log-odds probability of generating x_t of
the b'th image given subsequences of length (s, t). (See
mutual_information.py for more info). Shape [B][S][T + 1]
py: log-odds ratio of generating next y in the sequence.
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[s][t] the mutual information between
sub-sequences of x and y of length s and t respectively.
Its shape is [B][S + 1][T + 1]. This function implements
the following recursion:
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, which
contains, for each batch element, [s_begin, t_begin, s_end, t_end]
which are the beginning and end (one-past-the-last) of the
contains, for each batch element b, a row boundary[b] equal to
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
default to (0, 0, S, T), and they should not exceed these bounds
or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
nput: input image, shape (B, C, T) where B is batch size, C is
the number of channels and T is the time axis. (For more-than-1d
convolution setups, T would really be more than 1 axis, reshaped).
params: of shape (C, N+1) where N is the number of linear regions in the
piecewise linear function; params[c][0] is l which is
a log scale parameter that dictates how far apart
the discontinuities in the piecewise linear function are,
and params[c][n+1] for 0 <= n < N are the derivatives
of the linear parts of the piecewise linear function.
The discontinuities of the function are at:
exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
output: The transformed input, shape (B , C, T)
images_per_thread_block: The number of images processed by each thread
block. The calling code must guarantee that this is a power
of 2, and that EITHER:
THREADS_PER_BLOCK / images_per_thread_block >= T
OR
images_per_thread_block == 1
.. this is used for a small optimization.
This kernel is allocated with `extern_buf` containing enough memory
to store 2*N + 3 values of type scalar_t.
default to (0, 0, S, T); and they should not exceed these bounds.
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
*/
template <typename scalar_t,
int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32. Note: we require the
// num-threads be at least 128.
int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32.
__global__
void mutual_information_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
......@@ -450,8 +437,8 @@ void mutual_information_kernel(
epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
the above becomes
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t] * exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1] * exp(p[b][s][t])
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
Rearranging:
px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) (eq. 3a)
py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 3b)
......@@ -485,15 +472,21 @@ void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. Produced in forward pass.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B]. This is an input.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare, // [B]. A value will be written to here which
// should ideally equal ans_grad.
torch::PackedTensorAccessor32<scalar_t, 3> p_grad, // B, S + 1, T + 1. This is a temporary.
torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
torch::PackedTensorAccessor32<int64_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
int iter) { // This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
int iter, // This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
bool overwrite_ans_grad) { // If true, overwrite ans_grad with a value
// which, if everything is working correctly,
// should be identical or very close to the
// value of ans_grad that was passed in.
const int B = px.size(0),
S = px.size(1),
T = py.size(2);
......@@ -715,14 +708,13 @@ void mutual_information_backward_kernel(
}
if (threadIdx.x == 0 && s_block_begin == s_begin &&
t_block_end == t_end)
t_block_begin == t_begin && overwrite_ans_grad)
ans_grad[b] = p_buf[0][0];
}
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px,
......@@ -752,6 +744,9 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_blocks = 128,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
......@@ -777,11 +772,15 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// backward of mutual_information; returns (grad_px, grad_py)
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked
// as it should.
torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
torch::Tensor py,
std::optional<torch::Tensor> optional_boundary,
torch::Tensor p,
torch::Tensor ans_grad) {
torch::Tensor ans_grad,
bool overwrite_ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
......@@ -813,6 +812,9 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_blocks = 128,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
......@@ -833,7 +835,8 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
px_grad.packed_accessor32<scalar_t, 3>(),
py_grad.packed_accessor32<scalar_t, 3>(),
optional_boundary.value().packed_accessor32<int64_t, 2>(),
iter);
iter,
overwrite_ans_grad);
}
return std::vector<torch::Tensor>({px_grad, py_grad});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment