Commit 3c1ec347 authored by Daniel Povey's avatar Daniel Povey
Browse files

Get it to a stage where it looks like it might compile

parent 8ed6deff
...@@ -44,66 +44,72 @@ except ImportError: ...@@ -44,66 +44,72 @@ except ImportError:
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor, def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor: boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if input.is_cuda: if input.is_cuda:
if torch_mutual_information_cuda is None: if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module') raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda( return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundaries, q) px, py, boundaries, p)
else: else:
return torch_mutual_information_cpu.mutual_information_cpu( return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundaries, q) px, py, boundaries, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor, def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: boundaries: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda: if px.is_cuda:
if torch_mutual_information_cuda is None: if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module') raise EnvironmentError(f'Failed to load native CUDA module')
return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda( overwrite_ans_grad = True
px, py, boundaries, q)) if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excsssive roundoff in mutual information backward "
"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else: else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu( return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundaries, q)) px, py, boundaries, p, ans_grad))
class MutualInformationRecursionFunction(torch.autograd.Function): class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor: def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T) = px.shape (B, S, T1) = px.shape
T = T1 - 1;
assert py.shape == (B, S + 1, T)
if boundaries is not None:
assert boundaries.shape == (B, 4)
# p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is related to # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is the
# the mutual information of the pair of subsequences of x and y that are of # the mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is # length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths # the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively. # S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# q is a rearrangement of a tensor p which is of shape (B,S,T), # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# using p[b,s,t] == q[b,s+t,t]. The reason for working with this # p[b,s,t-1] + py[b,s,t-1])
# representation is that each row of q depends only on the previous row, # if s > 0 or t > 0,
# so we can access the rows sequenctially and this leads to # treating values with any -1 index as -infinity.
# better memory access patterns. We are assuming that most likely # .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
# T < S, which means that q should not require much more memory than p.
#
# Actually we access q beginning from 0 indexes even if `boundaries`
# has t_begin > 0 or s_begin > 0, i.e. we really access q as
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype) p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, p) ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
if px.requires_grad or py.requires_grad: if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, q) ctx.save_for_backward(px, py, boundaries, p)
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]: def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundaries, q) = ctx.saved_tensors (px, py, boundaries, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q) (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
return (px_grad, py_grad, None) return (px_grad, py_grad, None)
......
...@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts); bool has_boundary = (bool)optional_boundary;
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device()); auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
bool has_boundary = (bool)optional_boundary;
if (!has_boundary) if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts); optional_boundary = torch::empty({0, 0}, long_opts);
...@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
auto px_a = px.packed_accessor32<scalar_t, 3>(), auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(), py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(), p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p.packed_accessor32<scalar_t, 3>(); p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>(); auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
...@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
// .. which obtains p_a[b][s][t - 1] from a register. // .. which obtains p_a[b][s][t - 1] from a register.
scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t], scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t], total = p_a[b][s][t],
term1_deriv = exp(term1 - total), term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv, term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t], grad = p_grad_a[b][s][t],
term1_grad = term1_deriv * grad, term1_grad = term1_deriv * grad,
term2_grad = term2_deriv * grad; term2_grad = term2_deriv * grad;
// We can assign to px_grad_a here rather than add, because we
// know it's currently zero.
TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
px_grad_a[b][s - 1][t] = term1_grad; px_grad_a[b][s - 1][t] = term1_grad;
TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0); // likewise.. p_grad_a[b][s - 1][t] = term1_grad;
p_grad_a[b][s - 1][t] = term1_grad py_grad_a[b][s][t - 1] = term2_grad;
py_grad_a[b][s][t - 1] += term2_grad;
p_grad_a[b][s][t - 1] += term2_grad; p_grad_a[b][s][t - 1] += term2_grad;
} }
} }
...@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
} }
} }
})); }));
return ans;
}
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional"); std::cout << "p_grad = " << p_grad;
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional."); return std::vector<torch::Tensor>({px_grad, py_grad});
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(input.sizes() == output_grad.sizes(),
"Output-grad vs. input sizes mismatch.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2;
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor y_vals = torch::empty({C, N}, opts),
y_vals_grad = torch::zeros({C, N}, opts),
params_grad = torch::zeros({C, N + 1}, opts),
input_grad = torch::zeros({B, C, T}, opts);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
auto params_a = params.accessor<scalar_t, 2>(),
params_grad_a = params_grad.accessor<scalar_t, 2>(),
y_vals_a = y_vals.accessor<scalar_t, 2>(),
y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
for (int c = 0; c < C; c++) {
scalar_t sum_negative = 0.0,
sum_positive = 0.0,
scale = exp(params_a[c][0]);
for (int i = 0; i < K; i++) {
scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
neg_scaled_param = params_a[c][K - i] * scale;
y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
sum_positive += pos_scaled_param;
sum_negative -= neg_scaled_param;
y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
}
}
auto input_a = input.accessor<scalar_t, 3>(),
output_grad_a = output_grad.accessor<scalar_t, 3>(),
input_grad_a = input_grad.accessor<scalar_t, 3>();
for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) {
scalar_t inv_scale = exp(-params_a[c][0]);
for (int t = 0; t < T; t++) {
scalar_t input = input_a[b][c][t],
x = input * inv_scale + K,
output_grad = output_grad_a[b][c][t];
if (x < 0) x = 0;
else if (x >= N) x = N - 1;
// C++ rounds toward zero.
int n = (int) x;
// OK, at this point, 0 <= n < 2*K.
// backprop for:
// output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
params_grad_a[c][n + 1] += output_grad * input;
y_vals_grad_a[c][n] += output_grad;
input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
}
}
}
// Now do the backprop for the loop above where we set y_vals_a.
for (int c = 0; c < C; c++) {
scalar_t scale = exp(params_a[c][0]),
scale_grad = 0.0,
sum_negative_grad = 0.0,
sum_positive_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
sum_negative_grad += y_grad_neg;
scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
// Backprop for: sum_negative -= neg_scaled_param;
neg_scaled_param_grad -= sum_negative_grad;
// Backprop for: sum_positive += pos_scaled_param;
scalar_t pos_scaled_param_grad = sum_positive_grad;
// Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
pos_scaled_param_grad -= i * y_grad_pos;
sum_positive_grad += y_grad_pos;
// Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
// and: neg_scaled_param = params_a[c][K - i] * scale;
params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
neg_scaled_param_grad * params_a[c][K - i]);
}
// Backprop for: scale = exp(params_a[c][0]),
params_grad_a[c][0] += scale * scale_grad;
}
}));
return std::vector<torch::Tensor>({input_grad, params_grad});
} }
......
...@@ -73,7 +73,7 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) { ...@@ -73,7 +73,7 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
p[b,0,0] = 0.0 p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1]) p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
if s > 0 or t > 0, if s > 0 or t > 0,
treating values with any -1 index as -infinity. treating values with any -1 index as -infinity.
.. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0. .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
...@@ -122,32 +122,33 @@ void mutual_information_kernel( ...@@ -122,32 +122,33 @@ void mutual_information_kernel(
num_t_blocks = T / BLOCK_SIZE + 1; num_t_blocks = T / BLOCK_SIZE + 1;
// num_blocks_this_iter is an upper bound on the number of blocks of size // num_blocks_this_iter is an upper bound on the number of blocks of size
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration. We go // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
// from the bottom left of the image so that on iter == 0 we process only one // These iterations start from the bottom left of the image so that on iter ==
// block with block-index (0, 0) then on iter == 1 we process block-indexes // 0 we process only one block with block-index (0, 0) then on iter == 1 we
// (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0, // process block-indexes (1, 0) and (0, 1); and then on iter==2 we process (2,
// 2); and so on. We also will never have more than `num_s_blocks` blocks // 0), (1, 1) and (0, 2); and so on. We also will never have more than
// (We'll never have more than num_t_blocks either, but the numbering we use // `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
// corresponds to s and not t, so if we hit the num_t_blocks limit, the // the numbering we use corresponds to s and not t, so when we hit the
// lowest-numbered blocks on s would just not be active and we'll 'continue' // num_t_blocks limit, the blocks with the lowest s indexes would just not be
// below). // active and we'll 'continue' in the loop below).
int num_blocks_this_iter = min(iter + 1, num_s_blocks); int num_blocks_this_iter = min(iter + 1, num_s_blocks);
// For the block with s_block_begin == 0 and t_block_begin == 0 (for // For the block with s_block_begin == 0 and t_block_begin == 0 (for
// easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0 // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
// for out-of-range indexes. // for out-of-range indexes into px.
// Likewise, py_buf[s][t] will contain exp(py[s][t - 1]). // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE]; py_buf[BLOCK_SIZE][BLOCK_SIZE];
// 1st row/col of p_buf correspond to the previous blocks, or to an edge case. // p_buf[s][t] == exp(p[s+s_block_begin-1][t+t_block_begin-1] - normalizer).
// So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t // 1st row/col of p_buf correspond to the previously computed blocks (lower
// - 1] - normalizer); or 0 for out-of-range values. // `iter`), or to negative indexes into p. So, for the origin block,
// p_buf[s][t] corresponds to exp(p[s - 1][t - 1] - normalizer); or 0 for
// out-of-range values.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
// boundary_buf will be used to store the b'th row of `boundary` if we have // boundary_buf will be used to store the b'th row of `boundary` if we have
// boundary information supplied. // boundary information supplied; or (0, 0, S, T) otherwise.
__shared__ int64_t boundary_buf[4]; __shared__ int64_t boundary_buf[4];
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
...@@ -157,69 +158,70 @@ void mutual_information_kernel( ...@@ -157,69 +158,70 @@ void mutual_information_kernel(
boundary_buf[3] = T; boundary_buf[3] = T;
} }
// batch_block_iter iterates over both batch elements (index b), and block // batch_block_iter iterates over batch elements (index b) and block
// indexes in the range [0..num_blocks_this_iter-1] // indexes in the range [0..num_blocks_this_iter-1], combining both
// batch and block indexes.
for (int batch_block_iter = blockIdx.x; for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter; batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) { batch_block_iter += gridDim.x) {
int b = batch_block_iter % B, int block = batch_block_iter / B,
block = batch_block_iter / B; b = batch_block_iter % B; // b is the index into the batch
int s_block_begin = block * BLOCK_S_SIZE,
t_block_begin = (iter - block) * BLOCK_T_SIZE;
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
if (threadDim.x < 4 && boundary.size(0) != 0) if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadDim.x] = boundary[b][threadDim.x]; boundary_buf[threadDim.x] = boundary[b][threadDim.x];
__syncthreads(); __syncthreads();
int s_begin = boundary_buf[0], int s_begin = boundary_buf[0],
t_begin = boundary_buf[1], t_begin = boundary_buf[1],
s_end = boundary_buf[2], s_end = boundary_buf[2],
t_end = boundary_buf[3]; t_end = boundary_buf[3];
s_block_begin += s_begin; s_block_begin += s_begin;
t_block_begin += t_begin; t_block_begin += t_begin;
// block_S and block_T are the actual sizes of this block, no greater than // block_S and block_T are the actual sizes of this block (the block of `p`
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards // that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but
// the end of the sequence. // possibly less than that if we are towards the end of the sequence. The
// The last element of the output matrix p we write is (s_end, t_end), // last element in the output matrix p that we need to write is (s_end,
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1). // t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
if (block_S <= 0 || block_T <= 0) if (block_S <= 0 || block_T <= 0)
continue; continue;
bool is_origin_block = (s_block_begin * t_block_begin == 0); // Load px_buf and py_buf. We exponentiate; the assumption is that they
// most likely won't overflow or underflow, but if they do overflow we'll
// Load px_buf and py_buf. We exponentiate; the assumption is that they most likely // detect it later; we'll also detect certain kinds of underflow.
// won't overflow or underflow, but if they do overflow we'll detect it later; we'll for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// also detect certain kinds of underflow.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
// the comparisons with S and T below just make sure we don't access // comparing as unsigned int makes sure the index is nonnegative.
// out-of-memory regions; they do not guarantee we are in the range given
// by s_begin, s_end and so on. Note: comparing as unsigned int makes sure
// the index is nonnegative.
scalar_t this_px = 0.0; scalar_t this_px = 0.0;
if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) && if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(s_end) &&
t <= T) t <= t_end)
this_px = exp(px[b][s - 1][t]); this_px = exp(px[b][s - 1][t]);
px_buf[s_in_block][t_in_block] = this_px; px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = 0.0; scalar_t this_py = 0.0;
if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) && if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(t_end) &&
s <= S) s <= s_end)
this_py = exp(py[b][s][t - 1]); this_py = exp(py[b][s][t - 1]);
py_buf[s_in_block][t_in_block] = this_py; py_buf[s_in_block][t_in_block] = this_py;
} }
// Load the 1st row and column of p_buf (except element[0][0] is not needed). // Load the 1st row and 1st column of p_buf (except element[0][0] is not
// Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer. // needed). This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
// 1][t + t_block_begin - 1] - normalizer.
if (threadIdx.x < 64) { // 64 == warp size. First half of threads... if (threadIdx.x < 64) { // 64 == warp size. First half of threads...
if (threadIdx.x <= BLOCK_SIZE) { if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf are simply the indexes into p_buf // s_in_p_buf are simply the indexes into p_buf
...@@ -227,16 +229,14 @@ void mutual_information_kernel( ...@@ -227,16 +229,14 @@ void mutual_information_kernel(
t_in_p_buf = 0, t_in_p_buf = 0,
s = s_in_p_buf + s_block_begin - 1, s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1; t = t_in_p_buf + t_block_begin - 1;
// The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values.
scalar_t this_p = -INFINITY; scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) && if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) < static_cast<unsigned int>(T)) static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[s + s_block_begin][s + t_block_begin]; this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p; p_buf[threadIdx.x][0] = this_p;
} }
} else { // Another warp handles the other leg } else { // Another warp handles the other leg
if (threadIdx.x - 64 <= BLOCK_SIZE) { if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
int s_in_p_buf = 0, int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64, t_in_p_buf = threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1, s = s_in_p_buf + s_block_begin - 1,
...@@ -244,19 +244,19 @@ void mutual_information_kernel( ...@@ -244,19 +244,19 @@ void mutual_information_kernel(
// The if-statement below just guards against out-of-range memory // The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values. // accesses, it does not guarantee that we really need these values.
scalar_t this_p = -INFINITY; scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) && if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) < static_cast<unsigned int>(T)) static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[s + s_block_begin][s + t_block_begin]; this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p; p_buf[threadIdx.x][0] = this_p;
} }
} }
__syncthreads(); __syncthreads();
// We read p_buf in log-space; subtract 'normalizer', which mathematically // We read p_buf in log-space; we now subtract 'normalizer', which
// could be any finite number, to get in a reasonable range of probabilities, // mathematically could be any finite number, to get it in a range close to
// and then exponentiate. We'll do everything in non-log space, and later // zero, and then exponentiate. We'll do everything in non-log space, for
// take a log before we write out the data. // speed, and later take a log before we write out the data.
scalar_t normalizer = (is_origin_block ? 0.0 : scalar_t normalizer = (is_origin_block ? 0.0 :
max(px_buf[0][1], px_buf[1][0])); max(px_buf[0][1], px_buf[1][0]));
...@@ -265,50 +265,55 @@ void mutual_information_kernel( ...@@ -265,50 +265,55 @@ void mutual_information_kernel(
// and we'll overwrite with 1.0 if there is a panic situation due to // and we'll overwrite with 1.0 if there is a panic situation due to
// overflow. // overflow.
if (threadIdx.x <= BLOCK_SIZE) { if (threadIdx.x <= BLOCK_SIZE) {
if (threadIdx.x == 0) { // p_buf[0][0] is never used for its normal purpose; we set it to zero
// p_buf[0][0] is never used for its normal purpose; we set it to zero. // p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a // We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator. // 'panic' indicator.
p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 : p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
exp(p_buf[threadIdx.x][0] - normalizer)); exp(p_buf[threadIdx.x][0] - normalizer));
} } else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
} else if ((int)threadIdx.x - 64 < BLOCK_SIZE) { // this happens in a different warp so can be in parallel to the code above.
p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer); p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
} }
if (threadIdx.x == 0) { if (threadIdx.x == 0 && is_origin_block) {
// This if-statement is an optimization and modification of the loop below // This if-statement is an optimization and modification of the loop below
// for the value i == 0, i.e. inner-iteration == 0. The modification // for the value i == 0, i.e. inner-iteration == 0. The modification is
// is to use 0.0 if this is the "origin block" with s_block_begin == 0 and // to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
// t_block_begin == 0. This corresponds to the probability of the pair of // i.e. s == s_begin, t == t_begin. This corresponds to the
// sequences of length (0, 0). // probability of the pair of sequences of length (0, 0).
p_buf[1][1] = (is_origin_block ? 0.0 : p_buf[1][1] = (is_origin_block ? 1.0 :
p_buf[0][1] * px_buf[0][0] + p_buf[0][1] * px_buf[0][0] +
p_buf[1][0] * py_buf[0][0]); p_buf[1][0] * py_buf[0][0]);
} }
scalar_t p_buf_s1_t; // This is for an optimization. scalar_t p_buf_s1_t; // This is for an optimization to avoid one
if (i < BLOCK_SIZE) { // shared-memory read/write in the loop below. it
// represents p_buf[s + 1][t]; the first time we
// access this, it will be for t == 0, except for
// thread 0 when we first need it for t == 1.
if (threadIdx.x < BLOCK_SIZE) {
int s = threadIdx.x; int s = threadIdx.x;
p_buf_s1_t = p_buf[s + 1][0]; p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
} }
for (int i = 1; i < block_S + block_T; i++) { for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of the // i is the inner iteration, which corresponds to the (s + t) indexes of
// elements within the block that we write. So i == 0 writes positions // the elements within the block that we write. So i == 0 writes
// (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes // positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// (0, 2), (1, 1) and (2, 1); and so on. // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// Note: not many threads participate in this part, only up to BLOCK_SIZE // and (2, 1); and so on. Note: not many threads participate in this
// at most. Unfortunately we couldn't figure out a very meaningful way // part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// for more threads to do work, that looked like it would really spead // out a very meaningful way for more threads to do work, that looked like
// things up. // it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot, // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the // but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()). // inner loop simple and fast (e.g. no exp() or log()).
int s = threadIdx.x, int s = threadIdx.x,
t = i - s; t = i - s;
if (t >= 0) {
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account // row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and // the way these buffers relate to the tensors p, px and py, and
...@@ -320,7 +325,7 @@ void mutual_information_kernel( ...@@ -320,7 +325,7 @@ void mutual_information_kernel(
// //
// where you can see that apart from the offsets of tbb and sbb, this is // where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in // the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(). // mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1 #if 1
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t]; p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
#else #else
...@@ -328,18 +333,30 @@ void mutual_information_kernel( ...@@ -328,18 +333,30 @@ void mutual_information_kernel(
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory. // the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t]; p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t; p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif #endif
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
} }
__syncthreads(); __syncthreads();
} }
// Write out the data. // Write out the data to p; check that nothing has gone out of numerical
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { // range, and write 'panic' flag if it has.
int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE; for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
if (s < block_S && t < block_T) { int s_in_block = i / BLOCK_SIZE,
float this_p = p_buf[s + 1][t + 1]; t_in_block = i % BLOCK_SIZE,
p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p); s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
if (s_in_block < block_S && t_in_block < block_T) {
float this_p = p_buf[s_in_block + 1][t_in_block + 1];
p[b][s][t] = normalizer + log(this_p);
// If this_p is infinity, NaN or zero...
if (this_p - this_p != 0 || this_p == 0) if (this_p - this_p != 0 || this_p == 0)
p_buf[0][0] = 1.0; // This is a "panic" flag. p_buf[0][0] = 1.0; // This is a "panic" flag.
} }
...@@ -351,27 +368,31 @@ void mutual_information_kernel( ...@@ -351,27 +368,31 @@ void mutual_information_kernel(
// Write `ans`, if this is the final (top-right) block in its sequence // Write `ans`, if this is the final (top-right) block in its sequence
// Logically, the following equation corresponds to: // Logically, the following equation corresponds to:
// ans[b] = p[b][s_end][t_end] // ans[b] = p[b][s_end][t_end]
if (s_block_begin + S > s_end && t_block_begin + T > t_end) if (s_block_begin + block_S - 1 == s_end &&
ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]); t_block_begin + block_T - 1 == t_end) {
// you could read block_S below as block_S - 1 + 1, meaning,
// it's the last index in a block of size block_S, but the indexes into
// p_buf have a "+ 1". Likewise for block_T.
ans[b] = normalizer + log(p_buf[block_S][block_T]);
}
} }
if (p_buf[0][0] != 0.0) { if (p_buf[0][0] != 0.0) {
// "panic" flag set. We need to re-do the computation using log-add. // The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main // This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching // memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit. // should help us quite a bit.
for (int i = 0; i < 2 * BLOCK_SIZE; i++) { for (int i = 0; i < block_S + block_T - 1; ++i) {
int block_s = threadIdx.x, int s_in_block = threadIdx.x,
block_t = i - block_s; t_in_block = i - block_s;
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T) && if (s_in_block < block_S &&
block_s < block_S) { static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = block_s + s_block_begin, int s = s_in_block + s_block_begin,
t = block_t + t_block_begin; t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]), float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]), p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
this_px = px[b][s][t], this_py = py[b][s][t]; this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
float this_p = LogAdd(p_s1 + this_px, float this_p = LogAdd(p_s1 + this_px,
p_t1 + this_py); p_t1 + this_py);
if (i == 0 && is_origin_block) if (i == 0 && is_origin_block)
...@@ -382,7 +403,8 @@ void mutual_information_kernel( ...@@ -382,7 +403,8 @@ void mutual_information_kernel(
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
// Write `ans`, if this is the final (top-right) block in its sequence. // Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow. // This is only reached in the 'panic situation' where we had overflow.
if (s_block_begin + S > s_end && t_block_begin + T > t_end) if (s_block_begin + block_S - 1 == s_end &&
t_block_begin + block_T - 1 == t_end)
ans[b] = p[b][s_end][t_end]; ans[b] = p[b][s_end][t_end];
} }
} }
...@@ -402,17 +424,18 @@ void mutual_information_kernel( ...@@ -402,17 +424,18 @@ void mutual_information_kernel(
ep[b][s][t - 1] * epy[b][s][t - 1]. (eq. 1) ep[b][s][t - 1] * epy[b][s][t - 1]. (eq. 1)
(A) (A)
First we consider the part that involves recursion, i.e. the part involving only gradients of First we consider the part of the backprop that requires recursion or iteration,
ep. The backprop involving ep only would be: i.e. the part involving only gradients of ep. This is:
ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t] ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1]. ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].
.. and if we add 1 to the s index of the first equation above and 1 to the .. and if we add 1 to the s index of the first equation above and 1 to the
t index of the second equation, we can see that: t index of the second equation, we can see that:
ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] + ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
ep_grad[b][s][t + 1] * epy[b][s][t]. ep_grad[b][s][t + 1] * epy[b][s][t].
Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p) Now, if ep = exp(p), and y is the loss function we are backprop'ing,
then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
== dy/dp / ep. == p_grad / ep. == dy/dp / ep. == p_grad / ep.
I.e. ep_grad = p_grad / ep. I.e. ep_grad = p_grad / ep.
So we can write the above as: So we can write the above as:
...@@ -425,8 +448,8 @@ void mutual_information_kernel( ...@@ -425,8 +448,8 @@ void mutual_information_kernel(
(B) The following is the backprop for epx and epy from (eq. 1): (B) The following is the backprop for epx and epy from (eq. 1):
epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t] epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1] epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]
.. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd: .. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd:
...@@ -435,7 +458,7 @@ void mutual_information_kernel( ...@@ -435,7 +458,7 @@ void mutual_information_kernel(
Using, similar to the above, ep_grad = p_grad / ep, and similarly, Using, similar to the above, ep_grad = p_grad / ep, and similarly,
epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on, epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
the above becomes the above becomes:
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t]) px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t]) py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
...@@ -450,11 +473,11 @@ void mutual_information_kernel( ...@@ -450,11 +473,11 @@ void mutual_information_kernel(
yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 5) yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 5)
.. and note that these quantities are <= 1 so there is no problem doing .. and note that these quantities are <= 1 so there is no problem doing
the exponentiation. So the recursion can be simplified as: the exponentiation. So the recursion can be simplified as from eqs. (2, 3a, 3b), as:
p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] + p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 6) p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 6)
px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t] (eq. 7) px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] (eq. 7)
py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 8) py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 8)
(It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
...@@ -462,8 +485,9 @@ void mutual_information_kernel( ...@@ -462,8 +485,9 @@ void mutual_information_kernel(
write to shared memory within the loop that's the limiting factor.) write to shared memory within the loop that's the limiting factor.)
The backward pass will be slightly different from the forward pass in terms of The backward pass will be slightly different from the forward pass in terms of
how we store p (and p_grad), because for writing a particular block of p_grad, we how we store and index p (and p_grad), because for writing a particular block
need context on the top and right instead of the bottom and left. of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1.
*/ */
template <typename scalar_t> template <typename scalar_t>
__global__ __global__
...@@ -472,8 +496,6 @@ void mutual_information_backward_kernel( ...@@ -472,8 +496,6 @@ void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T. torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. Produced in forward pass. torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. Produced in forward pass.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B]. This is an input. torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B]. This is an input.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare, // [B]. A value will be written to here which
// should ideally equal ans_grad.
torch::PackedTensorAccessor32<scalar_t, 3> p_grad, // B, S + 1, T + 1. This is a temporary. torch::PackedTensorAccessor32<scalar_t, 3> p_grad, // B, S + 1, T + 1. This is a temporary.
torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1. torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T. torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
...@@ -483,16 +505,18 @@ void mutual_information_backward_kernel( ...@@ -483,16 +505,18 @@ void mutual_information_backward_kernel(
// be any sufficiently large number but will actually be: // be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S / // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1 // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
bool overwrite_ans_grad) { // If true, overwrite ans_grad with a value bool overwrite_ans_grad) { // If overwite_ans_grad == true, this function
// which, if everything is working correctly, // will overwrite ans_grad with a value which,
// should be identical or very close to the // if everything is working correctly, should be
// value of ans_grad that was passed in. // identical or very close to the value of
// ans_grad that was passed in.
const int B = px.size(0), const int B = px.size(0),
S = px.size(1), S = px.size(1),
T = py.size(2); T = py.size(2);
// For statements that are the same as the forward pass, we are omitting some comments // For statements that are the same as the forward pass, we are omitting some
// what we made there. We'll focus, in the comments, on differences from the forward pass. // comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1, const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_blocks_this_iter = min(iter + 1, num_s_blocks); num_blocks_this_iter = min(iter + 1, num_s_blocks);
...@@ -502,29 +526,33 @@ void mutual_information_backward_kernel( ...@@ -502,29 +526,33 @@ void mutual_information_backward_kernel(
// but then modified to store the "xderiv" and "yderiv" values defined // but then modified to store the "xderiv" and "yderiv" values defined
// in (eq. 5) and (eq. 6) above. For out-of-range values, we'll write 0.0 // in (eq. 5) and (eq. 6) above. For out-of-range values, we'll write 0.0
// here. // here.
// px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin]; // Initially (before xderiv/yderiv are written):
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin]. // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
// Later (see eq. 4 and eq. 5):
// px_buf[s][t] contains exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
// py_buf[s][t] contains exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]
// where ss == s + s_block_begin, tt = t + t_block_begin.
// Unlike in the forward code, there is no offset of 1 in the indexes. // Unlike in the forward code, there is no offset of 1 in the indexes.
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE]; py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf is initially used to store p, and then (after we are done putting // p_buf is initially used to store p, and then (after we are done putting
// xderiv and yderiv into px_buf and py_buf) it is repurposed to store // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
// p_grad. // p_grad.
// //
// Unlike in the forward pass, p_buf has the same numbering as px_buf and // Unlike in the forward pass, p_buf has the same numbering as px_buf and
// py_buf not offset by 1: e.g., for the origin block, p_buf[0][0] refers // py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0]
// to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than // refers to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than
// the block for px_buf and py_buf; unlike in the forward pass, we store // the block for px_buf and py_buf; unlike in the forward pass, we store
// context on the top right, not the bottom left, i.e. the elements at // context on the top and right, not the bottom and left, i.e. the elements at
// (one past the largest indexes in the block). // (one past the largest indexes in the block).
// //
// For out-of-range elements of p_buf, we'll put zero. // For out-of-range elements of p_buf, we'll put zero.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
// boundary_buf will be used to store the b'th row of `boundary` if we have // boundary_buf will be used to store the b'th row of `boundary` if we have
// boundary information supplied. // boundary information supplied; or (0, 0, S, T) if not.
__shared__ int64_t boundary_buf[4]; __shared__ int64_t boundary_buf[4];
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
...@@ -541,13 +569,13 @@ void mutual_information_backward_kernel( ...@@ -541,13 +569,13 @@ void mutual_information_backward_kernel(
for (int batch_block_iter = blockIdx.x; for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter; batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) { batch_block_iter += gridDim.x) {
int b = batch_block_iter % B, int block = batch_block_iter / B,
block = batch_block_iter / B; b = batch_block_iter % B;
int s_block_begin = block * BLOCK_S_SIZE, int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_T_SIZE; t_block_begin = (iter - block) * BLOCK_SIZE;
if (threadDim.x < 4 && boundary.size(0) != 0) if (threadIdx.x < 4 && boundary.size(0) != 0)
boundary_buf[threadDim.x] = boundary[b][threadDim.x]; boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads(); __syncthreads();
int s_begin = boundary_buf[0], int s_begin = boundary_buf[0],
...@@ -560,68 +588,69 @@ void mutual_information_backward_kernel( ...@@ -560,68 +588,69 @@ void mutual_information_backward_kernel(
// block_S and block_T are the actual sizes of this block, no greater than // block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence. // the end of the sequence.
// The last element of the output matrix p we write is (s_end, t_end), // The last element of the output matrix p_grad we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1). // i.e. the one-past-the-end index of p_grad is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
if (block_S <= 0 || block_T <= 0) if (block_S <= 0 || block_T <= 0)
continue; continue;
// Load px_buf and py_buf. At this point they just contain px and py // Load px_buf and py_buf. At this point we just set them to the px and py
// for this block. // for this block.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
// We let ps and py default to -infinity if they are out of range, which will // We let px and py default to -infinity if they are out of range, which will
// cause xderiv and yderiv for out-of-range values to be zero, and cause // cause xderiv and yderiv for out-of-range values to be zero, and cause
// correct behavior in edge cases (for the top and right blocks). // correct behavior in edge cases (for the top and right blocks).
// The issue is that p and p_grad are of larger size than px and py. // The issue is that p and p_grad are of larger size than px and py.
scalar_t this_px = -INFINITY; scalar_t this_px = -INFINITY;
if (s < s_end && t <= t_end) if (s < s_end && t <= t_end)
this_px = px[b][s - 1][t]; this_px = px[b][s][t];
px_buf[s_in_block][t_in_block] = this_px; px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = -INFINITY; scalar_t this_py = -INFINITY;
if (s <= s_end && t < t_end) if (s <= s_end && t < t_end)
this_py = py[b][s][t - 1]; this_py = py[b][s][t];
py_buf[s_in_block][t_in_block] = this_py; py_buf[s_in_block][t_in_block] = this_py;
} }
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// load p. This time we loop over the exact indexes we need. Above // reads more aligned.
// we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
// because having power-of-2 arrangement of threads may be helpful int s_in_block = i / (BLOCK_SIZE + 1),
// for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1) t_in_block = i % (BLOCK_SIZE + 1),
// which is not a power of 2, so that is not a concern here.
for (int i = threadDim.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
int s_in_block = i / (BLOCK_SIZE + 1), // 0 <= s_in_block <= block_S
t_in_block = i % (BLOCK_SIZE + 1), // 0 <= t_in_block <= block_T
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
// Setting 0.0 for out-of-bounds elements, together with setting // Setting 0.0 for out-of-bounds elements, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases, // ensure that we do the right thing in top and right edge cases,
// i.e. that no derivatives will be propagated from out-of-bounds points. // i.e. that no derivatives will be propagated from out-of-bounds points
p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ? // because the corresponding xderiv and yderiv values will be zero.
p[b][s][t] : 0.0); scalar_t this_p = 0.0;
if (s <= s_end && t <= t_end)
this_p = p[b][s][t];
p_buf[s_in_block][t_in_block] = this_p;
} }
// Set xderiv and yderiv; see (eq. 4) and (eq. 5). // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// We can apply this formula to the entire block even if we are processing // We can apply this formula to the entire block even if we are processing
// a partial block; elements outside the partial block will not be used so // a partial block; we have ensured that x_buf and y_buf contain -infinity,
// their values don't matter, and elements just out // and p contains 0, for out-of-range elements, so we'll get x_buf and y_buf
int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE; // containing 0 after applying the followin formulas.
int s = i / BLOCK_SIZE,
t = i % BLOCK_SIZE;
// Mathematically the following is doing: // Mathematically the following is doing:
// xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) // xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
// (with an offset on the s and t indexes) // (with an offset on the s and t indexes)
px_buf[s][t] = exp(px_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]); px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
// Mathematically the following is doing: // Mathematically the following is doing:
// yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) // yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
// (with an offset on the s and t indexes) // (with an offset on the s and t indexes)
py_buf[s][t] = exp(px_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]); py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
} }
// Load p_grad for the top and right elements in p_buf: i.e. for elements // Load p_grad for the top and right elements in p_buf: i.e. for elements
...@@ -630,7 +659,8 @@ void mutual_information_backward_kernel( ...@@ -630,7 +659,8 @@ void mutual_information_backward_kernel(
// never be accessed. // never be accessed.
// These are the p_grad values computed by previous instances of this kernel // These are the p_grad values computed by previous instances of this kernel
// If this is one of the top or right blocks, some or all of the p_grad // If this is one of the top or right blocks, some or all of the p_grad
// values we'd be reading here will be out of range, and we use zeros. // values we'd be reading here will be out of range, and we use zeros
// to ensure no gradient gets propagated from those positions.
if (threadIdx.x < block_S) { if (threadIdx.x < block_S) {
int s_in_block = threadIdx.x, int s_in_block = threadIdx.x,
t_in_block = block_T, t_in_block = block_T,
...@@ -638,34 +668,33 @@ void mutual_information_backward_kernel( ...@@ -638,34 +668,33 @@ void mutual_information_backward_kernel(
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = ( p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0); s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
} else if (static_cast<unsigned int>(threadIdx.x - 64) < } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) { static_cast<unsigned int>(block_T)) {
// casting to unsigned before the comparison tests for both negative and
// out-of-range values of (int)threadIdx.x - 64.
int s_in_block = block_S, int s_in_block = block_S,
t_in_block = threadIdx.x - 64, t_in_block = (int)threadIdx.x - 64,
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = ( p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0); s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
} }
// The number of inner iterations, i.e. iterations inside this // The highest-numbered value in p_buf that we need (corresponding,
// kernel, is this_num_inner_iters. The highest iteration, // of course, to p_grad), is:
// corresponding to the highest-indexed value of p_buf that // p_buf[block_S - 1][block_T - 1],
// we need to set, // and the inner iteration number (i) on which we set this is the sum of
// corresponds to p_buf[block_S - 1][block_T - 1], // these indexes, i.e. (block_S - 1) + (block_T - 1).
// and the iteration number is the sum of these indexes, i.e.
// (block_S - 1) + (block_T - 1).
bool is_final_block = (s_block_begin + block_S == s_end + 1 && bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
t_block_begin + block_T == t_end + 1); t_block_begin + block_T == t_end + 1);
int first_iter = block_S + block_T - 2; int first_iter = block_S + block_T - 2;
if (is_final_block) { if (is_final_block) {
// The following statement, mathematically, corresponds to: // The following statement corresponds to:
// p_grad[b][s_end][t_end] = ans_grad[b] Normally this element of p_buf // p_grad[b][s_end][t_end] = ans_grad[b]
// would be set by the first iteration of the loop below, so if it's set // Normally this element of p_buf would be set by the first iteration of
// this way we have to decrement first_iter to prevent it being // the loop below, so if it's set this way we have to decrement first_iter
// overwritten. // to prevent it from being overwritten.
p_buf[block_S - 1][block_T - 1] = ans_grad[b]; p_buf[block_S - 1][block_T - 1] = ans_grad[b];
--first_iter; --first_iter;
} }
...@@ -675,7 +704,8 @@ void mutual_information_backward_kernel( ...@@ -675,7 +704,8 @@ void mutual_information_backward_kernel(
t = i - threadIdx.x; t = i - threadIdx.x;
if (t >= 0) { if (t >= 0) {
// The following statement is really operating on the gradients; // The following statement is really operating on the gradients;
// it corresponds to (eq. 6) defined above, i.e.: // it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] + // p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t] // p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] + p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
...@@ -684,17 +714,19 @@ void mutual_information_backward_kernel( ...@@ -684,17 +714,19 @@ void mutual_information_backward_kernel(
} }
// Write out p_grad, px_grad and py_grad. // Write out p_grad, px_grad and py_grad.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int t_in_block = i % BLOCK_SIZE, int s_in_block = i / BLOCK_SIZE,
s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin; t = t_in_block + t_block_begin;
// s_end and t_end are the one-past-the-end of the (x,y) sequences, but
// the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
if (t <= t_end && s <= s_end) { if (t <= t_end && s <= s_end) {
p_grad[b][s][t] = p_buf[s_in_block][t_in_block]; p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1] if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1]
// From (eq. 7): // From (eq. 7):
// px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t] // px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] * px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
px_buf[s_in_block][t_in_block]); px_buf[s_in_block][t_in_block]);
} }
...@@ -741,7 +773,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, ...@@ -741,7 +773,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// num_threads and num_blocks and BLOCK_SIZE can be tuned. // num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128). // (however, num_threads may not be less than 128).
int num_threads = 128, int num_threads = 128,
num_blocks = 128, num_blocks = 256,
BLOCK_SIZE = 32; BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
...@@ -802,14 +834,18 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px, ...@@ -802,14 +834,18 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK(ans_grad.size(0) == b); TORCH_CHECK(ans_grad.size(0) == b);
bool has_boundary = (bool)optional_boundary;
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts), torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = torch::empty({B, S, T + 1}, opts), px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
py_grad = torch::empty({B, S + 1, T}, opts), torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S + 1, T}, opts));
// num_threads and num_blocks and BLOCK_SIZE can be tuned. // num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128). // (however, num_threads may not be less than 128).
const int num_threads = 128, const int num_threads = 128,
num_blocks = 128, num_blocks = 256,
BLOCK_SIZE = 32; BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
...@@ -819,7 +855,7 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px, ...@@ -819,7 +855,7 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1; num_iters = num_s_blocks + num_t_blocks - 1;
if ((bool)optional_boundary) if (has_boundary)
TORCH_CHECK(optional_boundary.value().device().is_cuda(), TORCH_CHECK(optional_boundary.value().device().is_cuda(),
"boundary information must be in CUDA tensor"); "boundary information must be in CUDA tensor");
else else
...@@ -838,5 +874,6 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px, ...@@ -838,5 +874,6 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
iter, iter,
overwrite_ans_grad); overwrite_ans_grad);
} }
std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad}); return std::vector<torch::Tensor>({px_grad, py_grad});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment