Commit 2b90f668 authored by Daniel Povey's avatar Daniel Povey

Pretty close to finishing all the core code, but need to check through it.

parent 77eed83f
......@@ -5,8 +5,6 @@
// returns log(exp(x) + exp(y)).
__forceinline__ __device__ double LogAdd(double x, double y) {
double diff;
......@@ -27,7 +25,6 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
// returns log(exp(x) + exp(y)).
__forceinline__ __device__ float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
......@@ -118,9 +115,12 @@ void mutual_information_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. This is an output.
torch::PackedTensorAccessor32<int64_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
torch::PackedTensorAccessor32<scalar_t, 1> ans, // [B]
int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
// (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
// so that each group depends on the previous group...
int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on,
// up to num_iters - 1 where
// num_iters = num_s_blocks + num_t_blocks - 1
// num_s_blocks = S / BLOCK_SIZE + 1
// num_t_blocks = T / BLOCK_SIZE + 1
// so that each group depends on the previous group...
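// E.g. with S = T = 40 and BLOCK_SIZE = 32, num_s_blocks = num_t_blocks = 2
// and num_iters = 3: iter 0 processes block (s_block, t_block) = (0, 0),
// iter 1 processes (0, 1) and (1, 0), and iter 2 processes (1, 1).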
const int B = px.size(0),
S = px.size(1),
T = py.size(2);
......@@ -180,35 +180,36 @@ void mutual_information_kernel(
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
int s_end, t_end; // s_end and t_end are the end points (last-plus-one) of the entire sequence.
if (threadIdx.x < 4 && boundary.size(0) != 0)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1];
s_end = boundary_buf[2];
t_end = boundary_buf[3];
t_begin = boundary_buf[1],
s_end = boundary_buf[2],
t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
// block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence.
int block_S = min(BLOCK_SIZE, s_end - s_block_begin),
block_T = min(BLOCK_SIZE, t_end - t_block_begin);
// The last element of the output matrix p we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
if (block_S <= 0 || block_T <= 0)
continue;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
// Load px_buf and py_buf. We exponentiate; the assumption is that they most likely
// won't overflow or underflow, but if they do overflow we'll detect it later; we'll
// also detect certain kinds of underflow.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int t_in_block = i % BLOCK_SIZE,
s_in_block = i / BLOCK_SIZE,
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
......@@ -305,7 +306,7 @@ void mutual_information_kernel(
p_buf_s1_t = p_buf[s + 1][0];
}
for (int i = 1; i < 2 * BLOCK_SIZE; i++) {
for (int i = 1; i < block_S + block_T; i++) {
// i is the inner iteration, which corresponds to the (s + t) indexes of the
// elements within the block that we write. So i == 0 writes positions
// (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
......@@ -402,44 +403,6 @@ void mutual_information_kernel(
/*
Summing reduction within a one-dimensional thread block, but with a
stride of N, so that we separately sum up the values of all threads with
threadIdx.x % N == 0, with threadIdx.x % N == 1, and so on. At the end,
threads with 0 <= threadIdx.x < N contain the sums.
So this is like tiled summing reduction except that the tiles are
interspersed with each other.
Args:
N: The number we sum modulo (must be a power of 2 with
1 <= N <= blockDim.x), i.e. all threads with
threadIdx.x % N == n for some 0 <= n < N have `val` summed.
buf: Pointer to the start of a __shared__ buffer of size
blockDim.x, to be used as a temporary within this function.
val: The value to be summed
Return:
Threads where threadIdx.x < N will return the sums (over the threads with
the same value of threadIdx.x % N);
the return value in other threads is undefined.
*/
template <typename scalar_t>
__forceinline__ __device__ scalar_t strided_reduce_sum(int N,
__volatile__ scalar_t *buf,
scalar_t val) {
// Each iteration halves the number of active threads
// Each thread adds its partial sum[i] to sum[lane+i]
for (int i = blockDim.x / 2; i >= N; i /= 2) {
buf[threadIdx.x] = val;
__syncthreads();
if (threadIdx.x < i)
val += buf[threadIdx.x + i];
}
return val; // Only threads with threadIdx.x < N will return the full sums of
// their groups.
}
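// For reference, a serial sketch of what strided_reduce_sum computes (this
// helper is illustrative only and is not part of the file): if vals[j] is the
// `val` held by thread j, then thread n (for 0 <= n < N) ends up with the sum
// of vals[j] over all j with j % N == n.
template <typename scalar_t>
void strided_reduce_sum_reference(int N, int num_threads,
                                  const scalar_t *vals, scalar_t *out) {
  for (int n = 0; n < N; n++) {
    scalar_t sum = 0;
    for (int j = n; j < num_threads; j += N)  // all j with j % N == n
      sum += vals[j];
    out[n] = sum;  // what thread n returns from strided_reduce_sum
  }
}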
/*
Backward of mutual_information.
......@@ -503,7 +466,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 6)
px_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 7)
px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] (eq. 7)
py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 8)
(It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
......@@ -525,9 +488,11 @@ void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
torch::PackedTensorAccessor32<int64_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
int iter) { // This kernel is sequentially called with 'iter' = num_iters - 1, num_iters - 2, .. 0,
// where num_iters can be taken to be any sufficiently large number but will actually be:
// (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
int iter) { // This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
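// E.g. with S = T = 40 and BLOCK_SIZE = 32 (as in the example in the
// forward kernel above), num_iters = 3 and this kernel is called with
// iter = 2, 1, 0, so the blocks furthest from the origin are processed first.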
const int B = px.size(0),
S = px.size(1),
T = py.size(2);
......@@ -543,6 +508,9 @@ void mutual_information_backward_kernel(
// but then modified to store the "xderiv" and "yderiv" values defined
// in (eq. 5) and (eq. 6) above. For out-of-range values, we'll write 0.0
// here.
// px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
// Unlike in the forward code, there is no offset of 1 in the indexes.
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE];
......@@ -565,278 +533,195 @@ void mutual_information_backward_kernel(
// boundary information supplied.
__shared__ int64_t boundary_buf[4];
boundary_buf[0] = 0;
boundary_buf[1] = 0;
boundary_buf[2] = S;
boundary_buf[3] = T;
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2; // Note: N and K are powers of 2, with K >= 1.
const int c = blockIdx.x; // c is channel index
scalar_t *y_vals = (scalar_t*) extern_buf, // [N], actually there are three
// spaces between here and
// `params_buf` for storing scale
// and inv_scale and l == params[c][0].
*params_buf = (scalar_t*) y_vals + 3 + N; // [N]. Contains parameters (not times scale!)
// Caution: contains params[c][1] through params[c][N],
// i.e. numbering is off by 1 versus params.
// params_buf[-1] contains params[c][0] == log of scale;
// params_buf[-2] and params_buf[-3] contain scale and inv_scale.
__shared__ scalar_t input_buf[THREADS_PER_BLOCK]; // input sequence
__shared__ scalar_t output_grad_buf[THREADS_PER_BLOCK];
__shared__ char n_buf[THREADS_PER_BLOCK]; // for each input in `input_buf`,
// this stores the integer value 0
// <= n < N which determines which
// piece of the piecewise linear
// function we are in.
// Load parameters
if (threadIdx.x <= N)
params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
__syncthreads();
if (threadIdx.x == 0) {
scalar_t scale = exp(params_buf[-1]);
params_buf[-2] = scale;
params_buf[-3] = 1.0 / scale;
boundary_buf[0] = 0;
boundary_buf[1] = 0;
boundary_buf[2] = S;
boundary_buf[3] = T;
}
__syncthreads();
if (threadIdx.x == 0) {
scalar_t scale = params_buf[-2],
sum_positive = 0.0;
for (int i = 0; i < K; i++) {
// params_buf is indexed with an index one less than params.
scalar_t pos_scaled_param = params_buf[K + i] * scale;
y_vals[K + i] = sum_positive - pos_scaled_param * i;
sum_positive += pos_scaled_param;
}
} else if (threadIdx.x == 64) {
scalar_t scale = params_buf[-2],
sum_negative = 0.0;
for (int i = 0; i < K; i++) {
scalar_t neg_scaled_param = params_buf[K - i - 1] * scale;
sum_negative -= neg_scaled_param;
y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
}
}
__syncthreads();
// this_param_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
// linear interval) corresponding to n == threadIdx.x % N. For example, if
// threadIdx.x == 0, this thread's gradient corresponds to the left-most
// linear interval.
scalar_t this_param_grad = 0.0,
this_y_vals_grad = 0.0;
scalar_t inv_scale = params_buf[-3];
int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
b_offset = threadIdx.x / T_inc; // offset within batch
for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
b += gridDim.y * images_per_thread_block) {
// The following will loop just once if images_per_thread_block > 1. If
// images_per_thread_block == 1 and T > THREADS_PER_BLOCK, we will loop
// multiple times. We want to keep all threads active so that output_grad
// will be set to zero for excess threads, and thus won't contribute to
// this_params_grad or this_y_vals_grad.
for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
// The following is equivalent to:
// int t = (threadIdx.x % T_inc) + t_offset;
// given that T_inc is a power of 2 and t_offset >= THREADS_PER_BLOCK >= T_inc.
int t = (threadIdx.x & (T_inc - 1)) | t_offset;
scalar_t this_input = 0.0, this_output_grad;
if (t < T) {
this_output_grad = output_grad[b][c][t];
this_input = input[b][c][t];
input_buf[threadIdx.x] = this_input;
output_grad_buf[threadIdx.x] = this_output_grad;
}
scalar_t x = this_input * inv_scale + K;
if (x < 0) x = 0;
else if (x >= N) x = N - 1;
// The forward code did:
// output[b][c][t] = this_input * params_buf[n] + y_vals[n];
// We get the derivative for params and y_vals later.
if (t < T) {
int n = (int)x; // C++ rounds toward zero.
n_buf[threadIdx.x] = (char)n;
input_grad[b][c][t] = this_output_grad * params_buf[n];
} else {
n_buf[threadIdx.x] = 255;
}
// batch_block_iter iterates over both batch elements (index b), and block
// indexes in the range [0..num_blocks_this_iter-1]. The order here
// doesn't matter, since there are no interdependencies between these
// blocks (they are on a diagonal).
for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int b = batch_block_iter % B,
block = batch_block_iter / B;
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
int this_block_start = threadIdx.x & ~(N-1), // == N * (threadIdx.x / N),
// since N is power of 2
this_n = threadIdx.x & (N-1); // == threadIdx.x % N.
// this_n is the n value that this thread accumulates gradients for;
// it is responsible for output_grads in the block of threads
// from this_block_start to this_block_start+N-1.
// __syncthreads(); // <- not really needed.
// At this point there is an implicit within-warp
// synchronization (Note: implicit warp synchronization is not considered
// future-proof). Threads above have written to n_buf, and threads below
// will read from it; but we don't need to explicitly synchronize for now
// because the reads/writes are among threads in a group of N threads with
// (4 <= N <= 16); and 16 is less than the warp size which is 32 or 64.
// src_indexes will contain up to 16 4-bit numbers, stored starting in its
// least significant bits. It will store all the offsets within this
// block of N threads, whose chosen 'n' value equals this_n.
uint64_t src_indexes = 0;
// num_src is the number of numbers in `src_indexes`. We need to store a
// separate counter because zero is a valid index and if we are to support
// N == 16 we don't have bits to spare in src_indexes to store some kind
// of marker.
int num_src = 0;
// This loop always does at least N statements, but they should be
// relatively fast ones since the computation per n value is minimal and
// there is little I/O. We are figuring out the subset of our block of N
// elements, which this particular thread value is responsible for
// (because they have n == this_n), and storing them in `src_indexes` and
// `num_src`.
for (int i = 0; i < N; i += 4) {
uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
#pragma unroll
for (int j = 0; j < 4; ++j) {
// CUDA is little endian
char n = (char)(n_block_of_4 >> (8*j));
if (n == this_n) {
// We require that N <= 16, so 4 bits is enough to store src_idx.
src_indexes = (src_indexes << 4) | (i + j);
++num_src;
}
// Note: if, for out-of-range threads, we had values not in [0..N-1] in
// n_buf they won't end up mattering even though they are read here,
// because they won't equal this_n. For values 0 <= n < N originating
// in out-of-range threads, the value won't matter because the
// corresponding value in output_grad_buf will be zero.
}
}
if (threadIdx.x < 4 && boundary.size(0) != 0)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
// While num_src could theoretically be as large as N, the hope is that no
// thread in any given warp actually loops that many times. Once all
// threads in the warp are finished looping, we can continue. It is OK
// for different warps to get out of sync here; we could be looping over a
// number of images, and the hope is that different warps will reach the
// end of the outer loop at around the same time because their variations
// in speed will average out.
for (; num_src > 0; --num_src, (src_indexes >>= 4)) {
int src_thread = this_block_start | (src_indexes & 0xF);
scalar_t src_output_grad = output_grad_buf[src_thread],
src_input = input_buf[src_thread];
assert(n_buf[src_thread] == this_n);
n_buf[src_thread] = 0;
// Backprop for: output = input * params_buf[n] + y_vals[n].
// Here, n == this_n; this is how we selected these `src_idx` values.
this_param_grad += src_output_grad * src_input;
this_y_vals_grad += src_output_grad;
}
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1],
s_end = boundary_buf[2],
t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
// TODO: remove the next lines
assert(n_buf[threadIdx.x] == 0 || (unsigned char)n_buf[threadIdx.x] == 255);
output_grad_buf[threadIdx.x] = 0.0;
}
}
// block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence.
// The last element of the output matrix p we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf for reduction, and, later, input_buf.
if (block_S <= 0 || block_T <= 0)
continue;
this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
__syncthreads();
this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
// Load px_buf and py_buf. At this point they just contain px and py
// for this block.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// We let px and py default to -infinity if they are out of range, which will
// cause xderiv and yderiv for out-of-range values to be zero, and cause
// correct behavior in edge cases (for the top and right blocks).
// The issue is that p and p_grad are of larger size than px and py.
scalar_t this_px = -INFINITY;
if (s < s_end && t <= t_end)
this_px = px[b][s][t];
px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = -INFINITY;
if (s <= s_end && t < t_end)
this_py = py[b][s][t];
py_buf[s_in_block][t_in_block] = this_py;
}
__syncthreads(); // sync threads because we are about to re-use
// output_grad_buf as y_vals_grad_buf.
// Re-use some buffers..
scalar_t *params_grad_buf = input_buf + 1, // [N] ... but element [-1] will have deriv of scale.
*y_vals_grad_buf = output_grad_buf; // [N]
// load p. This time we loop over the exact indexes we need. Above
// we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
// because having power-of-2 arrangement of threads may be helpful
// for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1)
// which is not a power of 2, so that is not a concern here.
for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
int s_in_block = i / (BLOCK_SIZE + 1), // 0 <= s_in_block <= block_S
t_in_block = i % (BLOCK_SIZE + 1), // 0 <= t_in_block <= block_T
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// Setting 0.0 for out-of-bounds elements, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases,
// i.e. that no derivatives will be propagated from out-of-bounds points.
p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
p[b][s][t] : 0.0);
}
if (threadIdx.x < N) {
params_grad_buf[threadIdx.x] = this_param_grad;
y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
}
__syncthreads(); // other threads are about to read params_grad_buf and
// y_vals_grad_buf.
// Set xderiv and yderiv; see (eq. 4) and (eq. 5).
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// We can apply this formula to the entire block even if we are processing
// a partial block; elements outside the partial block will not be used, so
// their values don't matter, and elements just outside the valid range end
// up with zero derivatives because px_buf and py_buf were set to -INFINITY
// there.
int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
// Mathematically the following is doing:
// xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
// (with an offset on the s and t indexes)
px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
// Mathematically the following is doing:
// yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
// (with an offset on the s and t indexes)
py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
}
// This next block does backprop relating to `y_vals`. Comparing with the CPU
// version (call this the "reference code") is the best way to understand this
// (this code is just a modification of that). The main difference is we
// modify the indexes into params and params_grad by -1, so the index
// corresponds to the 'n' value; and element -1 of params_grad_buf will have
// the deriv of the log scale.
// Load p_grad for the top and right elements in p_buf: i.e. for elements
// p_buf[s][t] where s == block_S (exclusive-or) t == block_T. We don't
// need to load the top-right corner [block_S][block_T]; that location will
// never be accessed.
// These are the p_grad values computed by previous instances of this kernel
// If this is one of the top or right blocks, some or all of the p_grad
// values we'd be reading here will be out of range, and we use zeros.
if (threadIdx.x < block_S) {
int s_in_block = threadIdx.x,
t_in_block = block_T,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} else if (static_cast<unsigned int>(threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) {
int s_in_block = block_S,
t_in_block = threadIdx.x - 64,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
}
scalar_t l_grad;
if (threadIdx.x == 0) {
// Now do the backprop for the loop above where we set y_vals_a. This could
// be further optimized to replace the loop with a raking, but I doubt this
// will have a huge effect on the runtime since K will be fairly small,
// e.g. 4.
scalar_t scale = params_buf[-2],
scale_grad = 0.0,
sum_positive_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// Backprop for: sum_positive += pos_scaled_param;
scalar_t pos_scaled_param_grad = sum_positive_grad;
// Backprop for: y_vals[K + i] = sum_positive - pos_scaled_param * i;
scalar_t y_grad_pos = y_vals_grad_buf[K + i];
pos_scaled_param_grad -= i * y_grad_pos;
sum_positive_grad += y_grad_pos;
// Backprop for: pos_scaled_param = params_buf[K + i] * scale,
params_grad_buf[K + i] += pos_scaled_param_grad * scale;
scale_grad += pos_scaled_param_grad * params_buf[K + i];
// The first (highest-numbered) inner iteration of the loop below,
// i.e. the first iteration inside this kernel, corresponds to the
// highest-indexed value of p_buf that we need to set, namely
// p_buf[block_S - 1][block_T - 1]; the iteration number is the sum of
// these two indexes, i.e. (block_S - 1) + (block_T - 1).
bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
t_block_begin + block_T == t_end + 1);
int first_iter = block_S + block_T - 2;
if (is_final_block) {
// The following statement, mathematically, corresponds to:
// p_grad[b][s_end][t_end] = ans_grad[b]. Normally this element of p_buf
// would be set by the first iteration of the loop below, so if it's set
// this way we have to decrement first_iter to prevent it being
// overwritten.
p_buf[block_S - 1][block_T - 1] = ans_grad[b];
--first_iter;
}
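// E.g. for a full interior block with block_S == block_T == BLOCK_SIZE == 32,
// first_iter is 62; for the final block it is decremented to 61 because
// p_buf[31][31] has already been seeded from ans_grad[b].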
// Backprop for: scale = exp(l), where l = params[c][0].
l_grad = scale * scale_grad;
} else if (threadIdx.x == 64) {
// Now do the backprop for the loop above where we set y_vals.
// Make this one threadIdx.x == 0 so it's possibly quicker to test
//
scalar_t scale = params_buf[-2],
scale_grad = 0.0,
sum_negative_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// Backprop for: y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
scalar_t y_grad_neg = y_vals_grad_buf[K - i - 1];
sum_negative_grad += y_grad_neg;
scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
// Backprop for: sum_negative -= neg_scaled_param;
neg_scaled_param_grad -= sum_negative_grad;
// Backprop for: neg_scaled_param = params_buf[K - i - 1] * scale;
params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
for (int i = first_iter; i >= 0; --i) {
int s = threadIdx.x,
t = i - s;
if (s < block_S && t >= 0 && t < block_T) {
// The following statement is really operating on the gradients;
// it corresponds to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]);
}
}
params_grad_buf[-1] = scale * scale_grad;
}
__syncthreads();
if (threadIdx.x == 0) {
params_grad_buf[-1] += l_grad; // contribution to l grad from the "negative" branch
}
__syncthreads();
if (threadIdx.x <= N) {
params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
// Write out p_grad, px_grad and py_grad.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int t_in_block = i % BLOCK_SIZE,
s_in_block = i / BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
if (t <= t_end && s <= s_end) {
p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1]
// From (eq. 7):
// px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
px_buf[s_in_block][t_in_block]);
}
if (t < t_end) { // write py_grad, which is of shape [B][S + 1][T]
// from (eq. 8):
// py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]
py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
py_buf[s_in_block][t_in_block]);
}
}
}
if (threadIdx.x == 0 && s_block_begin == s_begin &&
t_block_begin == t_begin)
ans_grad[b] = p_buf[0][0];
}
}
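// For reference, a rough single-threaded CPU sketch of the recursion that the
// kernels in this file parallelize, for one batch element b (illustrative
// only: the helper name, the std::vector containers and the seeding of
// p[s_begin][t_begin] with 0 are assumptions here, not part of this file;
// assumes <vector>, <cmath>, <algorithm> and <limits>):
template <typename scalar_t>
scalar_t mutual_information_forward_reference(
    const std::vector<std::vector<scalar_t>> &px,  // [S][T + 1]
    const std::vector<std::vector<scalar_t>> &py,  // [S + 1][T]
    std::vector<std::vector<scalar_t>> &p,         // [S + 1][T + 1], written to
    int s_begin, int t_begin, int s_end, int t_end) {
  const scalar_t neg_inf = -std::numeric_limits<scalar_t>::infinity();
  for (int s = s_begin; s <= s_end; ++s) {
    for (int t = t_begin; t <= t_end; ++t) {
      if (s == s_begin && t == t_begin) {
        p[s][t] = 0.0;
        continue;
      }
      scalar_t term1 = (s > s_begin ? p[s - 1][t] + px[s - 1][t] : neg_inf),
               term2 = (t > t_begin ? p[s][t - 1] + py[s][t - 1] : neg_inf);
      // log(exp(term1) + exp(term2)), like LogAdd() above.
      scalar_t m = std::max(term1, term2);
      p[s][t] = (m == neg_inf ? neg_inf :
                 m + std::log1p(std::exp(std::min(term1, term2) - m)));
    }
  }
  return p[s_end][t_end];  // this is ans[b]
}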
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px,
......@@ -861,18 +746,19 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
torch::Tensor ans = torch::empty({B}, opts);
int num_threads = 128,
num_blocks = 128;
num_blocks = 128,
BLOCK_SIZE = 32;
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = std::max<int>(num_s_blocks, num_t_blocks);
num_iters = num_s_blocks + num_t_blocks - 1;
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts);
for (int iter = 0; iter < num_iters; iter++) {
mutual_information_kernel<scalar_t, 32><<<num_blocks, num_threads>>>(
for (int iter = 0; iter < num_iters; ++iter) {
mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
px.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(),
......@@ -880,141 +766,64 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
ans.packed_accessor32<scalar_t, 1>(),
iter);
}
int grid_dim_y = 1;
// If the number of channels is quite small (<128) we can launch more thread
// groups, splitting on the batch index.
while (C * grid_dim_y < 128)
grid_dim_y *= 2;
// B_reduced is the max number of thread-groups per channel that would have
// any work to do. If grid_dim_y is more than this, we reduce it to avoid
// launching kernels with nothing to do.
int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
if (grid_dim_y > B_reduced)
grid_dim_y = B_reduced;
int shared_mem_numel = 2 * N + 3;
if (false)
std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
<< ", images_per_thread_block = " << images_per_thread_block
<< ", grid_dim_y = " << grid_dim_y
<< "\n";
TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
images_per_thread_block == 1,
"Code error");
TORCH_CHECK(N + 1 <= THREADS_PER_BLOCK,
"Values of N this large are not supported.");
dim3 gridDim(C, grid_dim_y, 1);
// blockDim is scalar, just THREADS_PER_BLOCK.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_kernel", ([&] {
mutual_information_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 3>(),
params.packed_accessor32<scalar_t, 2>(),
output.packed_accessor32<scalar_t, 3>(),
images_per_thread_block);
}));
return output;
return ans;
}
std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
torch::Tensor params,
torch::Tensor output_grad) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(output_grad.dim() == 3 && output_grad.size(0) == input.size(0) &&
output_grad.size(1) == input.size(1) &&
output_grad.size(2) == input.size(2),
"output_grad and input have mismatched dim.");
TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1;
TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
TORCH_CHECK((N & (N-1)) == 0, "N must be a power of 2")
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor input_grad = torch::empty({B, C, T}, opts);
if (C * B * T == 0) {
return std::vector<torch::Tensor>({input_grad,
torch::empty({C, N + 1})});
}
int images_per_thread_block = 1;
while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK &&
images_per_thread_block * 2 * N <= THREADS_PER_BLOCK)
images_per_thread_block *= 2;
int grid_dim_y = 1;
// If the number of channels is quite small (<128) we can launch more thread
// groups, splitting on the batch index.
while (C * grid_dim_y < 128)
grid_dim_y *= 2;
// B_reduced is the max number of thread-groups per channel that would have
// any work to do. If grid_dim_y is more than this, we reduce it to avoid
// launching kernels with nothing to do.
int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
if (grid_dim_y > B_reduced)
grid_dim_y = B_reduced;
// backward of mutual_information; returns (grad_px, grad_py)
torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
torch::Tensor py,
std::optional<torch::Tensor> optional_boundary,
torch::Tensor p,
torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
int shared_mem_numel = 2 * N + 3;
TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
p.device().is_cuda() && ans_grad.device().is_cuda(),
"inputs must be CUDA tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
if (false)
std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
<< ", images_per_thread_block = " << images_per_thread_block
<< ", grid_dim_y = " << grid_dim_y
<< "\n";
const int B = px.size(0),
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
images_per_thread_block == 1,
"Code error");
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= N);
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = torch::empty({B, S, T + 1}, opts),
py_grad = torch::empty({B, S + 1, T}, opts);
torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1}, opts);
const int num_threads = 128,
num_blocks = 128,
BLOCK_SIZE = 32;
dim3 gridDim(C, grid_dim_y, 1);
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
// blockDim is scalar, just THREADS_PER_BLOCK.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_kernel", ([&] {
mutual_information_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 3>(),
params.packed_accessor32<scalar_t, 2>(),
output_grad.packed_accessor32<scalar_t, 3>(),
input_grad.packed_accessor32<scalar_t, 3>(),
params_grad.packed_accessor32<scalar_t, 3>(),
images_per_thread_block);
}));
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts);
params_grad = at::sum(params_grad, {0});
return std::vector<torch::Tensor>({input_grad, params_grad});
for (int iter = num_iters - 1; iter >= 0; --iter) {
mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
px.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(),
ans_grad.packed_accessor32<scalar_t, 1>(),
p_grad.packed_accessor32<scalar_t, 3>(),
px_grad.packed_accessor32<scalar_t, 3>(),
py_grad.packed_accessor32<scalar_t, 3>(),
optional_boundary.value().packed_accessor32<int64_t, 2>(),
iter);
}
return std::vector<torch::Tensor>({px_grad, py_grad});
}
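// For reference, a rough single-threaded CPU sketch of the backward recursion
// ((eq. 4) through (eq. 8) in the comments above) that
// mutual_information_backward_kernel parallelizes, for one batch element b
// (illustrative only: the helper name and the std::vector containers are
// assumptions, not part of this file; assumes <vector> and <cmath>):
template <typename scalar_t>
void mutual_information_backward_reference(
    const std::vector<std::vector<scalar_t>> &px,  // [S][T + 1]
    const std::vector<std::vector<scalar_t>> &py,  // [S + 1][T]
    const std::vector<std::vector<scalar_t>> &p,   // [S + 1][T + 1]
    scalar_t ans_grad,                             // grad w.r.t. p[s_end][t_end]
    std::vector<std::vector<scalar_t>> &p_grad,    // [S + 1][T + 1], written to
    std::vector<std::vector<scalar_t>> &px_grad,   // [S][T + 1], written to
    std::vector<std::vector<scalar_t>> &py_grad,   // [S + 1][T], written to
    int s_begin, int t_begin, int s_end, int t_end) {
  p_grad[s_end][t_end] = ans_grad;
  for (int s = s_end; s >= s_begin; --s) {
    for (int t = t_end; t >= t_begin; --t) {
      // (eq. 4), (eq. 5): derivatives of the LogAdd recursion; zero when the
      // successor element is out of range.
      scalar_t xderiv = (s < s_end ?
                         std::exp(p[s][t] + px[s][t] - p[s + 1][t]) : 0.0),
               yderiv = (t < t_end ?
                         std::exp(p[s][t] + py[s][t] - p[s][t + 1]) : 0.0);
      if (s < s_end || t < t_end) {  // (s_end, t_end) was seeded from ans_grad.
        // (eq. 6)
        p_grad[s][t] = (s < s_end ? p_grad[s + 1][t] * xderiv : 0.0) +
                       (t < t_end ? p_grad[s][t + 1] * yderiv : 0.0);
      }
      if (s < s_end)  // (eq. 7)
        px_grad[s][t] = p_grad[s + 1][t] * xderiv;
      if (t < t_end)  // (eq. 8)
        py_grad[s][t] = p_grad[s][t + 1] * yderiv;
    }
  }
}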