"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "0c263a935f178490bda78c046e220e84b095447d"
Commit 2b90f668 authored by Daniel Povey

Pretty close to finishing all the core code, but need to check through it.

parent 77eed83f
@@ -5,8 +5,6 @@
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__ double LogAdd(double x, double y) {
   double diff;
@@ -27,7 +25,6 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__ inline float LogAdd(float x, float y) {
   float diff;
   if (x < y) {
     diff = x - y;
     x = y;
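For reference, a small host-side sketch of the identity these LogAdd overloads implement (illustrative only, not part of the commit):

#include <algorithm>
#include <cmath>

// log(exp(x) + exp(y)) computed stably by factoring out the larger argument:
//   log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)).
double log_add_reference(double x, double y) {
  double big = std::max(x, y), small = std::min(x, y);
  if (small == -INFINITY) return big;   // exp(-inf) == 0
  return big + std::log1p(std::exp(small - big));
}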
@@ -118,9 +115,12 @@ void mutual_information_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> p,       // B, S + 1, T + 1.  This is an output.
     torch::PackedTensorAccessor32<int64_t, 2> boundary,  // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
     torch::PackedTensorAccessor32<scalar_t, 1> ans,     // [B]
-    int iter) {  // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
-                 // (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
-                 // so that each group depends on the previous group...
+    int iter) {  // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on,
+                 // up to num_iters - 1 where
+                 //   num_iters = num_s_blocks + num_t_blocks - 1
+                 //   num_s_blocks = S / BLOCK_SIZE + 1
+                 //   num_t_blocks = T / BLOCK_SIZE + 1
+                 // so that each group depends on the previous group...
   const int B = px.size(0),
       S = px.size(1),
       T = py.size(2);
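A host-side sketch of the launch schedule the comment above describes (illustrative; the per-iteration block set follows directly from the comment, nothing here is taken from unshown code):

#include <cstdio>

// Enumerate which (s_block, t_block) pairs a given 'iter' covers, assuming
// square blocks of size BLOCK_SIZE.  Every block on anti-diagonal 'iter'
// depends only on blocks from earlier iterations, so the kernel can be
// launched once per iter.
void print_schedule(int S, int T, int BLOCK_SIZE) {
  int num_s_blocks = S / BLOCK_SIZE + 1,
      num_t_blocks = T / BLOCK_SIZE + 1,
      num_iters = num_s_blocks + num_t_blocks - 1;
  for (int iter = 0; iter < num_iters; ++iter)
    for (int s_block = 0; s_block < num_s_blocks; ++s_block) {
      int t_block = iter - s_block;
      if (t_block < 0 || t_block >= num_t_blocks) continue;
      std::printf("iter %d: block (%d, %d)\n", iter, s_block, t_block);
    }
}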
@@ -180,35 +180,36 @@ void mutual_information_kernel(
     int s_block_begin = block * BLOCK_S_SIZE,
         t_block_begin = (iter - block) * BLOCK_T_SIZE;
-    bool is_origin_block = (s_block_begin * t_block_begin == 0);
-    int s_end, t_end;
+    // s_end and t_end are the end points (last-plus-one) of the entire sequence.
     if (threadIdx.x < 4 && boundary.size(0) != 0)
       boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
     __syncthreads();
     int s_begin = boundary_buf[0],
-        t_begin = boundary_buf[1];
-    s_end = boundary_buf[2];
-    t_end = boundary_buf[3];
+        t_begin = boundary_buf[1],
+        s_end = boundary_buf[2],
+        t_end = boundary_buf[3];
     s_block_begin += s_begin;
     t_block_begin += t_begin;

     // block_S and block_T are the actual sizes of this block, no greater than
     // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
     // the end of the sequence.
-    int block_S = min(BLOCK_SIZE, s_end - s_block_begin),
-        block_T = min(BLOCK_SIZE, t_end - t_block_begin);
+    // The last element of the output matrix p we write is (s_end, t_end),
+    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
+        block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);

     if (block_S <= 0 || block_T <= 0)
       continue;

+    bool is_origin_block = (s_block_begin * t_block_begin == 0);
+
     // Load px_buf and py_buf.  We exponentiate; the assumption is that they most likely
     // won't overflow or underflow, but if they do overflow we'll detect it later; we'll
     // also detect certain kinds of underflow.
     for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-      int t_in_block = i % BLOCK_SIZE,
-          s_in_block = i / BLOCK_SIZE,
+      int s_in_block = i / BLOCK_SIZE,
+          t_in_block = i % BLOCK_SIZE,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
@@ -305,7 +306,7 @@ void mutual_information_kernel(
         p_buf_s1_t = p_buf[s + 1][0];
       }
-      for (int i = 1; i < 2 * BLOCK_SIZE; i++) {
+      for (int i = 1; i < block_S + block_T; i++) {
        // i is the inner iteration, which corresponds to the (s + t) indexes of the
        // elements within the block that we write.  So i == 0 writes positions
        // (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
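A sequential sketch of the within-block wavefront that this inner loop parallelizes (illustrative; the real kernel assigns one thread per s value on each anti-diagonal):

#include <utility>
#include <vector>

// p(s, t) depends only on p(s - 1, t) and p(s, t - 1), so all cells with the
// same value of s + t (one anti-diagonal) can be filled in parallel; the CUDA
// code above walks anti-diagonals i = 1 .. block_S + block_T - 1.
void wavefront_order(int block_S, int block_T,
                     std::vector<std::pair<int, int>> *order) {
  for (int i = 1; i < block_S + block_T; ++i)
    for (int s = 0; s < block_S; ++s) {
      int t = i - s;
      if (t >= 0 && t < block_T) order->emplace_back(s, t);
    }
}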
@@ -402,44 +403,6 @@ void mutual_information_kernel(
-/*
-  Summing reduction within a one-dimensional thread block, but with a
-  stride of N, so that we separately sum up the values of all threads with
-  threadIdx.x % N == 0, with threadIdx.x % N == 1, and so on.  At the end,
-  threads with 0 <= threadIdx.x < N contain the sums.
-
-  So this is like tiled summing reduction except that the tiles are
-  interspersed with each other.
-
-  Args:
-       N:    The number we sum modulo (must be a power of 2 with
-             1 <= N <= blockDim.x), i.e. all threads with
-             threadIdx.x % N == n for some 0 <= n < N have `val` summed.
-       buf:  Pointer to the start of a __shared__ buffer of size
-             blockDim.x, to be used as a temporary within this function.
-       val:  The value to be summed
-  Return:
-       Threads where threadIdx.x < N will return the sums (over the threads with
-       the same value of threadIdx.x % N);
-       the return value in other threads is undefined.
-*/
-template <typename scalar_t>
-__forceinline__ __device__ scalar_t strided_reduce_sum(int N,
-                                                       __volatile__ scalar_t *buf,
-                                                       scalar_t val) {
-  // Each iteration halves the number of active threads
-  // Each thread adds its partial sum[i] to sum[lane+i]
-  for (int i = blockDim.x / 2; i >= N; i /= 2) {
-    buf[threadIdx.x] = val;
-    __syncthreads();
-    if (threadIdx.x < i)
-      val += buf[threadIdx.x + i];
-  }
-  return val;  // Only threads with threadIdx.x < N will return the full sums of
-               // their groups.
-}
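A host-side model of the reduction described in the comment of the removed function, useful for checking the intended semantics (illustrative sketch; the device version interleaves the same tiles across one thread block):

#include <vector>

// Sum vals[j] over all j with j % N == n, for each n in [0, N); these are the
// per-group sums that strided_reduce_sum leaves in threads 0..N-1.
std::vector<float> strided_sums(const std::vector<float> &vals, int N) {
  std::vector<float> sums(N, 0.0f);
  for (size_t j = 0; j < vals.size(); ++j)
    sums[j % N] += vals[j];
  return sums;
}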
/*
  Backward of mutual_information.
@@ -503,7 +466,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
       p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
                         p_grad[b][s][t + 1] * yderiv[b][s][t]     (eq. 6)
-      px_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]    (eq. 7)
+      px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]    (eq. 7)
       py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]    (eq. 8)
  (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
@@ -525,9 +488,11 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
     torch::PackedTensorAccessor32<int64_t, 2> boundary,  // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
-    int iter) {  // This kernel is sequentially called with 'iter' = num_iters - 1, num_iters - 2, .. 0,
-                 // where num_iters can be taken to be any sufficiently large number but will actually be:
-                 // (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
+    int iter) {  // This kernel is sequentially called with 'iter' = num_iters
+                 // - 1, num_iters - 2, .. 0, where num_iters can be taken to
+                 // be any sufficiently large number but will actually be:
+                 // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
+                 // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
   const int B = px.size(0),
       S = px.size(1),
       T = py.size(2);
@@ -543,6 +508,9 @@ void mutual_information_backward_kernel(
   // but then modified to store the "xderiv" and "yderiv" values defined
   // in (eq. 5) and (eq. 6) above.  For out-of-range values, we'll write 0.0
   // here.
+  // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
+  // py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Unlike in the forward code, there is no offset of 1 in the indexes.
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
       py_buf[BLOCK_SIZE][BLOCK_SIZE];
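For orientation, a sequential single-batch reference for (eq. 6) to (eq. 8), operating on whole (S + 1) x (T + 1) arrays instead of per-block shared memory (an illustrative sketch, assuming the default boundary (0, 0, S, T); xderiv and yderiv are assumed to have already been formed as defined above):

#include <vector>

// All matrices are (S + 1) x (T + 1) row-major arrays for one batch element;
// m[s][t] is m[s * (T + 1) + t].  Only px_grad[0..S-1][*] and py_grad[*][0..T-1]
// are meaningful.  The recursion runs from high (s + t) to low (s + t),
// mirroring the reverse anti-diagonal order used by the kernel.
void backward_reference(int S, int T,
                        const std::vector<double> &xderiv,
                        const std::vector<double> &yderiv,
                        double ans_grad,
                        std::vector<double> *p_grad,
                        std::vector<double> *px_grad,
                        std::vector<double> *py_grad) {
  auto idx = [T](int s, int t) { return s * (T + 1) + t; };
  p_grad->assign((S + 1) * (T + 1), 0.0);
  px_grad->assign((S + 1) * (T + 1), 0.0);
  py_grad->assign((S + 1) * (T + 1), 0.0);
  for (int s = S; s >= 0; --s)
    for (int t = T; t >= 0; --t) {
      if (s == S && t == T) {
        (*p_grad)[idx(s, t)] = ans_grad;   // the output gradient enters at (S, T)
        continue;
      }
      double g = 0.0;                      // (eq. 6)
      if (s < S) g += (*p_grad)[idx(s + 1, t)] * xderiv[idx(s, t)];
      if (t < T) g += (*p_grad)[idx(s, t + 1)] * yderiv[idx(s, t)];
      (*p_grad)[idx(s, t)] = g;
    }
  for (int s = 0; s <= S; ++s)
    for (int t = 0; t <= T; ++t) {
      if (s < S)                           // (eq. 7)
        (*px_grad)[idx(s, t)] = (*p_grad)[idx(s + 1, t)] * xderiv[idx(s, t)];
      if (t < T)                           // (eq. 8)
        (*py_grad)[idx(s, t)] = (*p_grad)[idx(s, t + 1)] * yderiv[idx(s, t)];
    }
}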
@@ -565,278 +533,195 @@ void mutual_information_backward_kernel(
   // boundary information supplied.
   __shared__ int64_t boundary_buf[4];
-  boundary_buf[0] = 0;
-  boundary_buf[1] = 0;
-  boundary_buf[2] = S;
-  boundary_buf[3] = T;
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1,
-      K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
-
-  const int c = blockIdx.x;  // c is channel index
-
-  scalar_t *y_vals = (scalar_t*) extern_buf,  // [N], actually there are three
-                                              // spaces between here and
-                                              // `params_buf` for storing scale
-                                              // and inv_scale and l == params[c][0].
-      *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  Contains parameters (not times scale!)
-                                  // Caution: contains params[c][1] through params[c][N],
-                                  // i.e. numbering is off by 1 versus params.
-                                  // params_buf[-1] contains params[c][0] == log of scale;
-                                  // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
-
-  __shared__ scalar_t input_buf[THREADS_PER_BLOCK];  // input sequence
-  __shared__ scalar_t output_grad_buf[THREADS_PER_BLOCK];
-  __shared__ char n_buf[THREADS_PER_BLOCK];  // for each input in `input_buf`,
-                                             // this stores the integer value 0
-                                             // <= n < N which determines which
-                                             // piece of the piecewise linear
-                                             // function we are in.
-
-  // Load parameters
-  if (threadIdx.x <= N)
-    params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    scalar_t scale = exp(params_buf[-1]);
-    params_buf[-2] = scale;
-    params_buf[-3] = 1.0 / scale;
-  }
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    scalar_t scale = params_buf[-2],
-        sum_positive = 0.0;
-    for (int i = 0; i < K; i++) {
-      // params_buf is indexed with an index one less than params.
-      scalar_t pos_scaled_param = params_buf[K + i] * scale;
-      y_vals[K + i] = sum_positive - pos_scaled_param * i;
-      sum_positive += pos_scaled_param;
-    }
-  } else if (threadIdx.x == 64) {
-    scalar_t scale = params_buf[-2],
-        sum_negative = 0.0;
-    for (int i = 0; i < K; i++) {
-      scalar_t neg_scaled_param = params_buf[K - i - 1] * scale;
-      sum_negative -= neg_scaled_param;
-      y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-    }
-  }
-  __syncthreads();
-
-  // this_param_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
-  // linear interval) corresponding to n == threadIdx.x % N.  For example, if
-  // threadIdx.x == 0, this thread's gradient corresponds to the left-most
-  // linear interval.
-  scalar_t this_param_grad = 0.0,
-      this_y_vals_grad = 0.0;
-  scalar_t inv_scale = params_buf[-3];
-
-  int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
-      b_offset = threadIdx.x / T_inc;  // offset within batch
-
-  for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
-       b += gridDim.y * images_per_thread_block) {
-    // The following will loop just once if images_per_thread_block > 1.  If
-    // images_per_thread_block == 1 and T > THREADS_PER_BLOCK, we will loop
-    // multiple times.  We want to keep all threads active so that output_grad
-    // will be set to zero for excess threads, and thus won't contribute to
-    // this_params_grad or this_y_vals_grad.
-    for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
-      // The following is equivalent to:
-      //   int t = (threadIdx.x % T_inc) + t_offset;
-      // given that T_inc is a power of 2 and t_offset >= THREADS_PER_BLOCK >= T_inc.
-      int t = (threadIdx.x & (T_inc - 1)) | t_offset;
-      scalar_t this_input = 0.0, this_output_grad;
-      if (t < T) {
-        this_output_grad = output_grad[b][c][t];
-        this_input = input[b][c][t];
-        input_buf[threadIdx.x] = this_input;
-        output_grad_buf[threadIdx.x] = this_output_grad;
-      }
-      scalar_t x = this_input * inv_scale + K;
-      if (x < 0) x = 0;
-      else if (x >= N) x = N - 1;
-      // The forward code did:
-      //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
-      // We get the derivative for params and y_vals later.
-      if (t < T) {
-        int n = (int)x;  // C++ rounds toward zero.
-        n_buf[threadIdx.x] = (char)n;
-        input_grad[b][c][t] = this_output_grad * params_buf[n];
-      } else {
-        n_buf[threadIdx.x] = 255;
-      }
-
-      int this_block_start = threadIdx.x & ~(N-1),  // == N * (threadIdx.x / N),
-                                                    // since N is power of 2
-          this_n = threadIdx.x & (N-1);  // == threadIdx.x % N.
-      // this_n is the n value that this thread accumulates gradients for;
-      // it is responsible for output_grads in the block of threads
-      // from this_block_start to this_block_start+N-1.
-
-      // __syncthreads();  // <- not really needed.
-      // At this point there is an implicit within-warp
-      // synchronization (Note: implicit warp synchronization is not considered
-      // future-proof).  Threads above have written to n_buf, and threads below
-      // will read from it; but we don't need to explicitly synchronize for now
-      // because the reads/writes are among threads in a group of N threads with
-      // (4 <= N <= 16); and 16 is less than the warp size which is 32 or 64.
-
-      // src_indexes will contain up to 16 16-bit numbers, stored starting in its
-      // least significant bits.  It will store all the offsets within this
-      // block of N threads, whose chosen 'n' value equals this_n.
-      uint64_t src_indexes = 0;
-      // num_src is the number of numbers in `src_indexes`.  We need to store a
-      // separate counter because zero is a valid index and if we are to support
-      // N == 16 we don't have bits to spare in src_indexes to store some kind
-      // of marker.
-      int num_src = 0;
-
-      // This loop always does at least N statements, but they should be
-      // relatively fast ones since the computation per n value is minimal and
-      // there is little I/O.  We are figuring out the subset of our block of N
-      // elements, which this particular thread value is responsible for
-      // (because they have n == this_n), and storing them in `src_indexes` and
-      // `num_src`.
-      for (int i = 0; i < N; i += 4) {
-        uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
-#pragma unroll
-        for (int j = 0; j < 4; ++j) {
-          // CUDA is little endian
-          char n = (char)(n_block_of_4 >> (8*j));
-          if (n == this_n) {
-            // We require that N <= 16, so 4 bits is enough to store src_idx.
-            src_indexes = (src_indexes << 4) | (i + j);
-            ++num_src;
-          }
-          // Note: if, for out-of-range threads, we had values not in [0..N-1] in
-          // n_buf they won't end up mattering even though they are read here,
-          // because they won't equal this_n.  For values 0 <= n < N originating
-          // in out-of-range threads, the value won't matter because the
-          // corresponding value in output_grad_buf will be zero.
-        }
-      }
-
-      // While num_src could theoretically be as large as N, the hope is that no
-      // thread in any given warp actually loops that many times.  Once all
-      // threads in the warp are finished looping, we can continue.  It is OK
-      // for different warps to get out of sync here; we could be looping over a
-      // number of images, and the hope is that different warps will reach the
-      // end of the outer loop at around the same time because their variations
-      // in speed will average out.
-      for (; num_src > 0; --num_src, (src_indexes >>= 4)) {
-        int src_thread = this_block_start | (src_indexes & 0xF);
-        scalar_t src_output_grad = output_grad_buf[src_thread],
-            src_input = input_buf[src_thread];
-        assert(n_buf[src_thread] == this_n);
-        n_buf[src_thread] = 0;
-        // Backprop for: output = input * params_buf[n] + y_vals[n].
-        // Here, n == this_n; this is how we selected these `src_idx` values.
-        this_param_grad += src_output_grad * src_input;
-        this_y_vals_grad += src_output_grad;
-      }
-      // TODO: remove the next lines
-      assert(n_buf[threadIdx.x] == 0 || (unsigned char)n_buf[threadIdx.x] == 255);
-      output_grad_buf[threadIdx.x] = 0.0;
-    }
-  }
-
-  __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf for reduction, and, later, input_buf.
-  this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
-  __syncthreads();
-  this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
-  __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf as y_vals_grad_buf.
-
-  // Re-use some buffers..
-  scalar_t *params_grad_buf = input_buf + 1,  // [N] ... but element [-1] will have deriv of scale.
-      *y_vals_grad_buf = output_grad_buf;  // [N]
-
-  if (threadIdx.x < N) {
-    params_grad_buf[threadIdx.x] = this_param_grad;
-    y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
-  }
-  __syncthreads();  // other threads are about to read params_grad_buf and
-                    // y_vals_grad_buf.
-
-  // This next block does backprop relating to `y_vals`.  Comparing with the CPU
-  // version (call this the "reference code") is the best way to understand this
-  // (this code is just a modification of that).  The main difference is we
-  // modify the indexes into params and params_grad by -1, so the index
-  // corresponds to the 'n' value; and element -1 of params_grad_buf will have
-  // the deriv of the log scale.
-  scalar_t l_grad;
-  if (threadIdx.x == 0) {
-    // Now do the backprop for the loop above where we set y_vals_a.  This could
-    // be further optimized to replace the loop with a raking, but I doubt this
-    // will have a huge effect on the runtime since K will be fairly small,
-    // e.g. 4.
-    scalar_t scale = params_buf[-2],
-        scale_grad = 0.0,
-        sum_positive_grad = 0.0;
-    for (int i = K - 1; i >= 0; i--) {
-      // Backprop for: sum_positive += pos_scaled_param;
-      scalar_t pos_scaled_param_grad = sum_positive_grad;
-      // Backprop for: y_vals[K + i] = sum_positive - pos_scaled_param * i;
-      scalar_t y_grad_pos = y_vals_grad_buf[K + i];
-      pos_scaled_param_grad -= i * y_grad_pos;
-      sum_positive_grad += y_grad_pos;
-      // Backprop for: pos_scaled_param = params_buf[K + i] * scale,
-      params_grad_buf[K + i] += pos_scaled_param_grad * scale;
-      scale_grad += pos_scaled_param_grad * params_buf[K + i];
-    }
-    // Backprop for: scale = exp(l), where l = params[c][0].
-    l_grad = scale * scale_grad;
-  } else if (threadIdx.x == 64) {
-    // Now do the backprop for the loop above where we set y_vals.
-    // Make this one threadIdx.x == 0 so it's possibly quicker to test
-    //
-    scalar_t scale = params_buf[-2],
-        scale_grad = 0.0,
-        sum_negative_grad = 0.0;
-    for (int i = K - 1; i >= 0; i--) {
-      // Backprop for: y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-      scalar_t y_grad_neg = y_vals_grad_buf[K - i - 1];
-      sum_negative_grad += y_grad_neg;
-      scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-      // Backprop for: sum_negative -= neg_scaled_param;
-      neg_scaled_param_grad -= sum_negative_grad;
-      // Backprop for: neg_scaled_param = params_buf[K - i - 1] * scale;
-      params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
-      scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
-    }
-    params_grad_buf[-1] = scale * scale_grad;
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    params_grad_buf[-1] += l_grad;  // contribution to l grad from the "negative" branch
-  }
-  __syncthreads();
-  if (threadIdx.x <= N) {
-    params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
-  }
+  if (threadIdx.x == 0) {
+    boundary_buf[0] = 0;
+    boundary_buf[1] = 0;
+    boundary_buf[2] = S;
+    boundary_buf[3] = T;
+  }
+
+  // batch_block_iter iterates over both batch elements (index b), and block
+  // indexes in the range [0..num_blocks_this_iter-1].  The order here
+  // doesn't matter, since there are no interdependencies between these
+  // blocks (they are on a diagonal).
+  for (int batch_block_iter = blockIdx.x;
+       batch_block_iter < B * num_blocks_this_iter;
+       batch_block_iter += gridDim.x) {
+    int b = batch_block_iter % B,
+        block = batch_block_iter / B;
+    int s_block_begin = block * BLOCK_S_SIZE,
+        t_block_begin = (iter - block) * BLOCK_T_SIZE;
+
+    if (threadIdx.x < 4 && boundary.size(0) != 0)
+      boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
+    __syncthreads();
+
+    int s_begin = boundary_buf[0],
+        t_begin = boundary_buf[1],
+        s_end = boundary_buf[2],
+        t_end = boundary_buf[3];
+    s_block_begin += s_begin;
+    t_block_begin += t_begin;
+
+    // block_S and block_T are the actual sizes of this block, no greater than
+    // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
+    // the end of the sequence.
+    // The last element of the output matrix p we write is (s_end, t_end),
+    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
+        block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
+
+    if (block_S <= 0 || block_T <= 0)
+      continue;
+
+    // Load px_buf and py_buf.  At this point they just contain px and py
+    // for this block.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE,
+          t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      // We let px and py default to -infinity if they are out of range, which will
+      // cause xderiv and yderiv for out-of-range values to be zero, and cause
+      // correct behavior in edge cases (for the top and right blocks).
+      // The issue is that p and p_grad are of larger size than px and py.
+      scalar_t this_px = -INFINITY;
+      if (s < s_end && t <= t_end)
+        this_px = px[b][s - 1][t];
+      px_buf[s_in_block][t_in_block] = this_px;
+      scalar_t this_py = -INFINITY;
+      if (s <= s_end && t < t_end)
+        this_py = py[b][s][t - 1];
+      py_buf[s_in_block][t_in_block] = this_py;
+    }
+
+    // load p.  This time we loop over the exact indexes we need.  Above
+    // we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
+    // because having power-of-2 arrangement of threads may be helpful
+    // for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1)
+    // which is not a power of 2, so that is not a concern here.
+    for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
+      int s_in_block = i / (BLOCK_SIZE + 1),  // 0 <= s_in_block <= block_S
+          t_in_block = i % (BLOCK_SIZE + 1),  // 0 <= t_in_block <= block_T
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      // Setting 0.0 for out-of-bounds elements, together with setting
+      // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
+      // ensure that we do the right thing in top and right edge cases,
+      // i.e. that no derivatives will be propagated from out-of-bounds points.
+      p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
+                                       p[b][s][t] : 0.0);
+    }
+
+    // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      // We can apply this formula to the entire block even if we are processing
+      // a partial block; elements outside the partial block will not be used so
+      // their values don't matter, and elements just out
+      int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
+      // Mathematically the following is doing:
+      //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
+      // (with an offset on the s and t indexes)
+      px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+      // Mathematically the following is doing:
+      //   yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
+      // (with an offset on the s and t indexes)
+      py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
+    }
+
+    // Load p_grad for the top and right elements in p_buf: i.e. for elements
+    // p_buf[s][t] where s == block_S (exclusive-or) t == block_T.  We don't
+    // need to load the top-right corner [block_S][block_T]; that location will
+    // never be accessed.
+    // These are the p_grad values computed by previous instances of this kernel.
+    // If this is one of the top or right blocks, some or all of the p_grad
+    // values we'd be reading here will be out of range, and we use zeros.
+    if (threadIdx.x < block_S) {
+      int s_in_block = threadIdx.x,
+          t_in_block = block_T,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] = (
+          s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
+    } else if (static_cast<unsigned int>(threadIdx.x - 64) <
+               static_cast<unsigned int>(block_T)) {
+      int s_in_block = block_S,
+          t_in_block = threadIdx.x - 64,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] = (
+          s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
+    }
+
+    // The number of inner iterations, i.e. iterations inside this
+    // kernel, is this_num_inner_iters.  The highest iteration,
+    // corresponding to the highest-indexed value of p_buf that
+    // we need to set, corresponds to p_buf[block_S - 1][block_T - 1],
+    // and the iteration number is the sum of these indexes, i.e.
+    // (block_S - 1) + (block_T - 1).
+
+    bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
+                           t_block_begin + block_T == t_end + 1);
+
+    int first_iter = block_S + block_T - 2;
+    if (is_final_block) {
+      // The following statement, mathematically, corresponds to:
+      //   p_grad[b][s_end][t_end] = ans_grad[b].  Normally this element of p_buf
+      // would be set by the first iteration of the loop below, so if it's set
+      // this way we have to decrement first_iter to prevent it being
+      // overwritten.
+      p_buf[block_S - 1][block_T - 1] = ans_grad[b];
+      --first_iter;
+    }
+
+    for (int i = first_iter; i >= 0; --i) {
+      int s = threadIdx.x,
+          t = i - threadIdx.x;
+      if (t >= 0) {
+        // The following statement is really operating on the gradients;
+        // it corresponds to (eq. 6) defined above, i.e.:
+        //   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
+        //                     p_grad[b][s][t + 1] * yderiv[b][s][t]
+        p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
+                       p_buf[s][t + 1] * py_buf[s][t]);
+      }
+    }
+
+    // Write out p_grad, px_grad and py_grad.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int t_in_block = i % BLOCK_SIZE,
+          s_in_block = i / BLOCK_SIZE,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      if (t <= t_end && s <= s_end) {
+        p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
+        if (s < s_end) {  // write px_grad, which is of shape [B][S][T + 1]
+          // From (eq. 7):
+          //   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
+          px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
+                              px_buf[s_in_block][t_in_block]);
+        }
+        if (t < t_end) {  // write py_grad, which is of shape [B][S + 1][T]
+          // from (eq. 8):
+          //   py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]
+          py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
+                              py_buf[s_in_block][t_in_block]);
+        }
+      }
+    }
+    if (threadIdx.x == 0 && s_block_begin == s_begin &&
+        t_block_begin == t_begin)
+      ans_grad[b] = p_buf[0][0];
+  }
 }
// forward of mutual_information.  See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px,
@@ -861,18 +746,19 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
   torch::Tensor ans = torch::empty({B}, opts);

-  int num_threads = 128,
-      num_blocks = 128;
+  const int num_threads = 128,
+      num_blocks = 128,
+      BLOCK_SIZE = 32;
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
-      num_iters = std::max<int>(num_s_blocks, num_t_blocks);
+      num_iters = num_s_blocks + num_t_blocks - 1;

   bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);

-  for (int iter = 0; iter < num_iters; iter++) {
-    mutual_information_kernel<scalar_t, 32><<<num_blocks, num_threads>>>(
+  for (int iter = 0; iter < num_iters; ++iter) {
+    mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
         px.packed_accessor32<scalar_t, 3>(),
         py.packed_accessor32<scalar_t, 3>(),
         p.packed_accessor32<scalar_t, 3>(),
@@ -880,141 +766,64 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
         ans.packed_accessor32<scalar_t, 1>(),
         iter);
   }
-  int grid_dim_y = 1;
-  // If the number of channels is quite small (<128) we can launch more thread
-  // groups, splitting on the batch index.
-  while (C * grid_dim_y < 128)
-    grid_dim_y *= 2;
-  // B_reduced is the max number of thread-groups per channel that would have
-  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
-  // launching kernels with nothing to do.
-  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
-  if (grid_dim_y > B_reduced)
-    grid_dim_y = B_reduced;
-  int shared_mem_numel = 2 * N + 3;
-  if (false)
-    std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
-              << ", images_per_thread_block = " << images_per_thread_block
-              << ", grid_dim_y = " << grid_dim_y
-              << "\n";
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
-              images_per_thread_block == 1,
-              "Code error");
-  TORCH_CHECK(N + 1 <= THREADS_PER_BLOCK,
-              "Values of N this large are not supported.");
-  dim3 gridDim(C, grid_dim_y, 1);
-  // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_kernel", ([&] {
-    mutual_information_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
-          input.packed_accessor32<scalar_t, 3>(),
-          params.packed_accessor32<scalar_t, 2>(),
-          output.packed_accessor32<scalar_t, 3>(),
-          images_per_thread_block);
-      }));
-  return output;
+  return ans;
 }
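A minimal call-site sketch for the forward wrapper above (illustrative only; the parameter list of mutual_information_cuda beyond `px` is not visible in this hunk, so the (py, boundary, p) order below is an assumption modeled on the backward signature that follows):

#include <optional>
#include <torch/torch.h>

torch::Tensor run_forward_example() {
  const int B = 2, S = 40, T = 100;
  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  torch::Tensor px = torch::randn({B, S, T + 1}, opts),      // per-step "x" scores
                py = torch::randn({B, S + 1, T}, opts),      // per-step "y" scores
                p  = torch::empty({B, S + 1, T + 1}, opts);  // filled by the kernel
  std::optional<torch::Tensor> boundary;                     // default (0, 0, S, T)
  // Hypothetical call; adjust to the real signature declared elsewhere in this file.
  return mutual_information_cuda(px, py, boundary, p);       // shape [B]
}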
-std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
-                                                            torch::Tensor params,
-                                                            torch::Tensor output_grad) {
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(output_grad.dim() == 3 && output_grad.size(0) == input.size(0) &&
-              output_grad.size(1) == input.size(1) &&
-              output_grad.size(2) == input.size(2),
-              "output_grad and input have mismatched dim.");
-  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
-  TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
-  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1;
-  TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
-  TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
-  TORCH_CHECK((N & (N-1)) == 0, "N must be a power of 2")
-
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-
-  torch::Tensor input_grad = torch::empty({B, C, T}, opts);
-
-  if (C * B * T == 0) {
-    return std::vector<torch::Tensor>({input_grad,
-            torch::empty({C, N + 1})});
-  }
-
-  int images_per_thread_block = 1;
-  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK &&
-         images_per_thread_block * 2 * N <= THREADS_PER_BLOCK)
-    images_per_thread_block *= 2;
-
-  int grid_dim_y = 1;
-  // If the number of channels is quite small (<128) we can launch more thread
-  // groups, splitting on the batch index.
-  while (C * grid_dim_y < 128)
-    grid_dim_y *= 2;
-  // B_reduced is the max number of thread-groups per channel that would have
-  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
-  // launching kernels with nothing to do.
-  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
-  if (grid_dim_y > B_reduced)
-    grid_dim_y = B_reduced;
-
-  int shared_mem_numel = 2 * N + 3;
-
-  if (false)
-    std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
-              << ", images_per_thread_block = " << images_per_thread_block
-              << ", grid_dim_y = " << grid_dim_y
-              << "\n";
-
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
-              images_per_thread_block == 1,
-              "Code error");
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= N);
-
-  torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1}, opts);
-
-  dim3 gridDim(C, grid_dim_y, 1);
-
-  // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_kernel", ([&] {
-    mutual_information_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
-          input.packed_accessor32<scalar_t, 3>(),
-          params.packed_accessor32<scalar_t, 2>(),
-          output_grad.packed_accessor32<scalar_t, 3>(),
-          input_grad.packed_accessor32<scalar_t, 3>(),
-          params_grad.packed_accessor32<scalar_t, 3>(),
-          images_per_thread_block);
-      }));
-
-  params_grad = at::sum(params_grad, {0});
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+// backward of mutual_information; returns (grad_px, grad_py)
+std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor px,
+                                                            torch::Tensor py,
+                                                            std::optional<torch::Tensor> optional_boundary,
+                                                            torch::Tensor p,
+                                                            torch::Tensor ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
+
+  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
+              p.device().is_cuda() && ans_grad.device().is_cuda(),
+              "inputs must be CUDA tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0),
+      S = px.size(1),
+      T = px.size(2) - 1;
+
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
+      px_grad = torch::empty({B, S, T + 1}, opts),
+      py_grad = torch::empty({B, S + 1, T}, opts);
+
+  const int num_threads = 128,
+      num_blocks = 128,
+      BLOCK_SIZE = 32;
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+      num_t_blocks = T / BLOCK_SIZE + 1,
+      num_iters = num_s_blocks + num_t_blocks - 1;
+
+  bool has_boundary = (bool)optional_boundary;
+  if (!has_boundary)
+    optional_boundary = torch::empty({0, 0}, long_opts);
+
+  for (int iter = num_iters - 1; iter >= 0; --iter) {
+    mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
+        px.packed_accessor32<scalar_t, 3>(),
+        py.packed_accessor32<scalar_t, 3>(),
+        p.packed_accessor32<scalar_t, 3>(),
+        ans_grad.packed_accessor32<scalar_t, 1>(),
+        p_grad.packed_accessor32<scalar_t, 3>(),
+        px_grad.packed_accessor32<scalar_t, 3>(),
+        py_grad.packed_accessor32<scalar_t, 3>(),
+        optional_boundary.value().packed_accessor32<int64_t, 2>(),
+        iter);
+  }
+  return std::vector<torch::Tensor>({px_grad, py_grad});
 }
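These two wrappers would typically be exposed to Python through a standard torch extension module; the binding below is a sketch and is not part of this commit (the module and attribute names are assumptions, and mutual_information.py presumably wraps whatever is actually exported in a torch.autograd.Function):

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda", &mutual_information_cuda,
        "Mutual information forward (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
        "Mutual information backward (CUDA), returns (px_grad, py_grad)");
}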