Commit 3c1ec347 authored by Daniel Povey

Get it to a stage where it looks like it might compile

parent 8ed6deff
......@@ -44,66 +44,72 @@ except ImportError:
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundaries, q)
px, py, boundaries, p)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundaries, q)
px, py, boundaries, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
boundaries: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundaries, q))
overwrite_ans_grad = True
if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excsssive roundoff in mutual information backward "
"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundaries, q))
px, py, boundaries, p, ans_grad))
class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
(B, S, T) = px.shape
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T1) = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
if boundaries is not None:
assert boundaries.shape == (B, 4)
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is
# the mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively.
# q is a rearrangement of a tensor p which is of shape (B,S,T),
# using p[b,s,t] == q[b,s+t,t]. The reason for working with this
# representation is that each row of q depends only on the previous row,
# so we can access the rows sequentially and this leads to
# better memory access patterns. We are assuming that most likely
# T < S, which means that q should not require much more memory than p.
#
# Actually we access q beginning from 0 indexes even if `boundaries`
# has t_begin > 0 or s_begin > 0, i.e. we really access q as
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, q)
ctx.save_for_backward(px, py, boundaries, p)
return ans
@staticmethod
def backward(ctx, ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundaries, q) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
(px, py, boundaries, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
return (px_grad, py_grad, None)
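For reference, the recursion that these dispatchers evaluate can be written as a naive PyTorch double loop. This is only a sketch for intuition and testing, not part of the extension; the function name is made up and boundary handling is omitted:

import torch

def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: (B, S, T + 1), py: (B, S + 1, T); returns a tensor of shape (B,).
    # Implements p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
    #                               p[b,s,t-1] + py[b,s,t-1]), with p[b,0,0] = 0.
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = float('-inf')
    p = torch.full((B, S + 1, T + 1), neg_inf, dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    ninf = torch.full((B,), neg_inf, dtype=px.dtype, device=px.device)
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term1 = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else ninf
            term2 = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else ninf
            p[:, s, t] = torch.logaddexp(term1, term2)
    return p[:, S, T]

A loop like this could serve as a test oracle for MutualInformationRecursionFunction.apply on small inputs.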
......
......@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
bool has_boundary = (bool)optional_boundary;
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts);
......@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p.packed_accessor32<scalar_t, 3>();
p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
......@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
// .. which obtains p_a[b][s][t - 1] from a register.
scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t],
term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t],
term1_grad = term1_deriv * grad,
term2_grad = term2_deriv * grad;
// We can assign to px_grad_a here rather than add, because we
// know it's currently zero.
TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
px_grad_a[b][s - 1][t] = term1_grad;
TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0); // likewise..
p_grad_a[b][s - 1][t] = term1_grad
py_grad_a[b][s][t - 1] += term2_grad;
p_grad_a[b][s - 1][t] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
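As an aside, the term1_deriv / term2_deriv expressions above follow from differentiating total = log(exp(term1) + exp(term2)): d(total)/d(term1) = exp(term1 - total), and the two partial derivatives sum to one. A standalone check with torch.autograd (illustrative only, not part of the extension):

import torch

term1 = torch.tensor(0.3, requires_grad=True)
term2 = torch.tensor(-1.2, requires_grad=True)
total = torch.logaddexp(term1, term2)
total.backward()
term1_deriv = (term1 - total).exp()      # as in the loop above
assert torch.isclose(term1.grad, term1_deriv)
assert torch.isclose(term2.grad, 1.0 - term1_deriv)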
......@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
}
}
}));
return ans;
}
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(input.sizes() == output_grad.sizes(),
"Output-grad vs. input sizes mismatch.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2;
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor y_vals = torch::empty({C, N}, opts),
y_vals_grad = torch::zeros({C, N}, opts),
params_grad = torch::zeros({C, N + 1}, opts),
input_grad = torch::zeros({B, C, T}, opts);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
auto params_a = params.accessor<scalar_t, 2>(),
params_grad_a = params_grad.accessor<scalar_t, 2>(),
y_vals_a = y_vals.accessor<scalar_t, 2>(),
y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
for (int c = 0; c < C; c++) {
scalar_t sum_negative = 0.0,
sum_positive = 0.0,
scale = exp(params_a[c][0]);
for (int i = 0; i < K; i++) {
scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
neg_scaled_param = params_a[c][K - i] * scale;
y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
sum_positive += pos_scaled_param;
sum_negative -= neg_scaled_param;
y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
}
}
auto input_a = input.accessor<scalar_t, 3>(),
output_grad_a = output_grad.accessor<scalar_t, 3>(),
input_grad_a = input_grad.accessor<scalar_t, 3>();
for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) {
scalar_t inv_scale = exp(-params_a[c][0]);
for (int t = 0; t < T; t++) {
scalar_t input = input_a[b][c][t],
x = input * inv_scale + K,
output_grad = output_grad_a[b][c][t];
if (x < 0) x = 0;
else if (x >= N) x = N - 1;
// C++ rounds toward zero.
int n = (int) x;
// OK, at this point, 0 <= n < 2*K.
// backprop for:
// output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
params_grad_a[c][n + 1] += output_grad * input;
y_vals_grad_a[c][n] += output_grad;
input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
}
}
}
// Now do the backprop for the loop above where we set y_vals_a.
for (int c = 0; c < C; c++) {
scalar_t scale = exp(params_a[c][0]),
scale_grad = 0.0,
sum_negative_grad = 0.0,
sum_positive_grad = 0.0;
for (int i = K - 1; i >= 0; i--) {
// Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
sum_negative_grad += y_grad_neg;
scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
// Backprop for: sum_negative -= neg_scaled_param;
neg_scaled_param_grad -= sum_negative_grad;
// Backprop for: sum_positive += pos_scaled_param;
scalar_t pos_scaled_param_grad = sum_positive_grad;
// Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
pos_scaled_param_grad -= i * y_grad_pos;
sum_positive_grad += y_grad_pos;
// Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
// and: neg_scaled_param = params_a[c][K - i] * scale;
params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
neg_scaled_param_grad * params_a[c][K - i]);
}
// Backprop for: scale = exp(params_a[c][0]),
params_grad_a[c][0] += scale * scale_grad;
}
}));
return std::vector<torch::Tensor>({input_grad, params_grad});
std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad});
}
......
......@@ -73,7 +73,7 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
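For reference, LogAdd (declared in the hunk header above) computes log(exp(x) + exp(y)); a numerically stable scalar version, sketched here in Python under the usual formulation, factors out the larger argument:

import math

def log_add(x: float, y: float) -> float:
    # log(exp(x) + exp(y)); factoring out max(x, y) avoids overflow,
    # and the -inf check avoids computing (-inf) - (-inf) = nan.
    if x < y:
        x, y = y, x
    if y == float('-inf'):
        return x
    return x + math.log1p(math.exp(y - x))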
......@@ -122,32 +122,33 @@ void mutual_information_kernel(
num_t_blocks = T / BLOCK_SIZE + 1;
// num_blocks_this_iter is an upper bound on the number of blocks of size
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration. We go
// from the bottom left of the image so that on iter == 0 we process only one
// block with block-index (0, 0) then on iter == 1 we process block-indexes
// (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0,
// 2); and so on. We also will never have more than `num_s_blocks` blocks
// (We'll never have more than num_t_blocks either, but the numbering we use
// corresponds to s and not t, so if we hit the num_t_blocks limit, the
// lowest-numbered blocks on s would just not be active and we'll 'continue'
// below).
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
// These iterations start from the bottom left of the image so that on iter ==
// 0 we process only one block with block-index (0, 0) then on iter == 1 we
// process block-indexes (1, 0) and (0, 1); and then on iter==2 we process (2,
// 0), (1, 1) and (0, 2); and so on. We also will never have more than
// `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
// the numbering we use corresponds to s and not t, so when we hit the
// num_t_blocks limit, the blocks with the lowest s indexes would just not be
// active and we'll 'continue' in the loop below).
int num_blocks_this_iter = min(iter + 1, num_s_blocks);
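To make this anti-diagonal scheduling concrete, a small Python sketch (names assumed, not part of the kernel) listing the (s_block, t_block) pairs processed on a given `iter`:

def active_blocks(iter, num_s_blocks, num_t_blocks):
    # Blocks on anti-diagonal `iter`: s_block + t_block == iter,
    # clipped to the valid grid; mirrors the kernel's `continue` logic.
    blocks = []
    for block in range(min(iter + 1, num_s_blocks)):
        s_block, t_block = block, iter - block
        if t_block < num_t_blocks:
            blocks.append((s_block, t_block))
    return blocks

# e.g. active_blocks(2, 4, 3) -> [(0, 2), (1, 1), (2, 0)]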
// For the block with s_block_begin == 0 and t_block_begin == 0 (for
// easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
// for out-of-range indexes.
// for out-of-range indexes into px.
// Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE];
// 1st row/col of p_buf correspond to the previous blocks, or to an edge case.
// So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t
// - 1] - normalizer); or 0 for out-of-range values.
// p_buf[s][t] == exp(p[s+s_block_begin-1][t+t_block_begin-1] - normalizer).
// 1st row/col of p_buf correspond to the previously computed blocks (lower
// `iter`), or to negative indexes into p. So, for the origin block,
// p_buf[s][t] corresponds to exp(p[s - 1][t - 1] - normalizer); or 0 for
// out-of-range values.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
// boundary_buf will be used to store the b'th row of `boundary` if we have
// boundary information supplied.
// boundary information supplied; or (0, 0, S, T) otherwise.
__shared__ int64_t boundary_buf[4];
if (threadIdx.x == 0) {
......@@ -157,69 +158,70 @@ void mutual_information_kernel(
boundary_buf[3] = T;
}
// batch_block_iter iterates over both batch elements (index b), and block
// indexes in the range [0..num_blocks_this_iter-1]
// batch_block_iter iterates over batch elements (index b) and block
// indexes in the range [0..num_blocks_this_iter-1], combining both
// batch and block indexes.
for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int b = batch_block_iter % B,
block = batch_block_iter / B;
int s_block_begin = block * BLOCK_S_SIZE,
t_block_begin = (iter - block) * BLOCK_T_SIZE;
int block = batch_block_iter / B,
b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
if (threadDim.x < 4 && boundary.size(0) != 0)
if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1],
s_end = boundary_buf[2],
t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
// block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence.
// The last element of the output matrix p we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
// block_S and block_T are the actual sizes of this block (the block of `p`
// that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but
// possibly less than that if we are towards the end of the sequence. The
// last element in the output matrix p that we need to write is (s_end,
// t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
if (block_S <= 0 || block_T <= 0)
continue;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
// Load px_buf and py_buf. We exponentiate; the assumption is that they most likely
// won't overflow or underflow, but if they do overflow we'll detect it later; we'll
// also detect certain kinds of underflow.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// Load px_buf and py_buf. We exponentiate; the assumption is that they
// most likely won't overflow or underflow, but if they do overflow we'll
// detect it later; we'll also detect certain kinds of underflow.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// the comparisons with S and T below just make sure we don't access
// out-of-memory regions; they do not guarantee we are in the range given
// by s_begin, s_end and so on. Note: comparing as unsigned int makes sure
// the index is nonnegative.
// comparing as unsigned int makes sure the index is nonnegative.
scalar_t this_px = 0.0;
if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) &&
t <= T)
if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(s_end) &&
t <= t_end)
this_px = exp(px[b][s - 1][t]);
px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = 0.0;
if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) &&
s <= S)
if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(t_end) &&
s <= s_end)
this_py = exp(py[b][s][t - 1]);
py_buf[s_in_block][t_in_block] = this_py;
}
// Load the 1st row and column of p_buf (except element[0][0] is not needed).
// Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer.
// Load the 1st row and 1st column of p_buf (except element[0][0] is not
// needed). This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
// 1][t + t_block_begin - 1] - normalizer).
if (threadIdx.x < 64) { // First half of the 128 threads (two 32-thread warps)...
if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf are simply the indexes into p_buf
......@@ -227,16 +229,14 @@ void mutual_information_kernel(
t_in_p_buf = 0,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
// The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values.
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
this_p = p[s + s_block_begin][s + t_block_begin];
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p;
}
} else { // Another warp handles the other leg
if (threadIdx.x - 64 <= BLOCK_SIZE) {
if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1,
......@@ -244,19 +244,19 @@ void mutual_information_kernel(
// The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values.
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
this_p = p[s + s_block_begin][s + t_block_begin];
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[0][threadIdx.x - 64] = this_p;  // 1st row of p_buf
}
}
__syncthreads();
// We read p_buf in log-space; subtract 'normalizer', which mathematically
// could be any finite number, to get in a reasonable range of probabilities,
// and then exponentiate. We'll do everything in non-log space, and later
// take a log before we write out the data.
// We read p_buf in log-space; we now subtract 'normalizer', which
// mathematically could be any finite number, to get it in a range close to
// zero, and then exponentiate. We'll do everything in non-log space, for
// speed, and later take a log before we write out the data.
scalar_t normalizer = (is_origin_block ? 0.0 :
max(p_buf[0][1], p_buf[1][0]));
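The normalizer trick described above can be illustrated with a tiny PyTorch example (values made up): subtract a representative log-space value before exponentiating, then add it back when taking the log at the end.

import torch

p_log = torch.tensor([-980.0, -981.3, -979.6])   # log-space values near each other
normalizer = p_log.max()
p_lin = (p_log - normalizer).exp()                # safe to exponentiate now
# ... arithmetic in linear (non-log) space happens here ...
p_back = normalizer + p_lin.log()                 # recover the log-space values
assert torch.allclose(p_back, p_log)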
......@@ -265,50 +265,55 @@ void mutual_information_kernel(
// and we'll overwrite with 1.0 if there is a panic situation due to
// overflow.
if (threadIdx.x <= BLOCK_SIZE) {
if (threadIdx.x == 0) {
// p_buf[0][0] is never used for its normal purpose; we set it to zero.
// p_buf[0][0] is never used for its normal purpose; we set it to zero
// p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator.
p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
exp(p_buf[threadIdx.x][0] - normalizer));
}
} else if ((int)threadIdx.x - 64 < BLOCK_SIZE) {
} else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
// this happens in a different warp so can be in parallel to the code above.
p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
}
if (threadIdx.x == 0) {
if (threadIdx.x == 0 && is_origin_block) {
// This if-statement is an optimization and modification of the loop below
// for the value i == 0, i.e. inner-iteration == 0. The modification
// is to use 0.0 if this is the "origin block" with s_block_begin == 0 and
// t_block_begin == 0. This corresponds to the probability of the pair of
// sequences of length (0, 0).
p_buf[1][1] = (is_origin_block ? 0.0 :
// for the value i == 0, i.e. inner-iteration == 0. The modification is
// to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
// i.e. s == s_begin, t == t_begin. This corresponds to the
// probability of the pair of sequences of length (0, 0).
p_buf[1][1] = (is_origin_block ? 1.0 :
p_buf[0][1] * px_buf[0][0] +
p_buf[1][0] * py_buf[0][0]);
}
scalar_t p_buf_s1_t; // This is for an optimization.
if (i < BLOCK_SIZE) {
scalar_t p_buf_s1_t; // This is for an optimization to avoid one
// shared-memory read/write in the loop below. it
// represents p_buf[s + 1][t]; the first time we
// access this, it will be for t == 0, except for
// thread 0 when we first need it for t == 1.
if (threadIdx.x < BLOCK_SIZE) {
int s = threadIdx.x;
p_buf_s1_t = p_buf[s + 1][0];
}
for (int i = 1; i < block_S + block_T; i++) {
// i is the inner iteration, which corresponds to the (s + t) indexes of the
// elements within the block that we write. So i == 0 writes positions
// (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
// (0, 2), (1, 1) and (2, 1); and so on.
// Note: not many threads participate in this part, only up to BLOCK_SIZE
// at most. Unfortunately we couldn't figure out a very meaningful way
// for more threads to do work, that looked like it would really spead
// things up.
p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
}
for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 0); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work that looked like
// it would really speed things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
int s = threadIdx.x,
t = i - s;
if (t >= 0) {
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
......@@ -320,7 +325,7 @@ void mutual_information_kernel(
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion().
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
#else
......@@ -328,18 +333,30 @@ void mutual_information_kernel(
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
__syncthreads();
}
// Write out the data.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
if (s < block_S && t < block_T) {
float this_p = p_buf[s + 1][t + 1];
p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p);
// Write out the data to p; check that nothing has gone out of numerical
// range, and write 'panic' flag if it has.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
if (s_in_block < block_S && t_in_block < block_T) {
float this_p = p_buf[s_in_block + 1][t_in_block + 1];
p[b][s][t] = normalizer + log(this_p);
// If this_p is infinity, NaN or zero...
if (this_p - this_p != 0 || this_p == 0)
p_buf[0][0] = 1.0; // This is a "panic" flag.
}
......@@ -351,27 +368,31 @@ void mutual_information_kernel(
// Write `ans`, if this is the final (top-right) block in its sequence
// Logically, the following equation corresponds to:
// ans[b] = p[b][s_end][t_end]
if (s_block_begin + S > s_end && t_block_begin + T > t_end)
ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]);
if (s_block_begin + block_S - 1 == s_end &&
t_block_begin + block_T - 1 == t_end) {
// you could read block_S below as block_S - 1 + 1, meaning,
// it's the last index in a block of size block_S, but the indexes into
// p_buf have a "+ 1". Likewise for block_T.
ans[b] = normalizer + log(p_buf[block_S][block_T]);
}
}
if (p_buf[0][0] != 0.0) {
// "panic" flag set. We need to re-do the computation using log-add.
// The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
for (int i = 0; i < 2 * BLOCK_SIZE; i++) {
int block_s = threadIdx.x,
block_t = i - block_s;
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T) &&
block_s < block_S) {
int s = block_s + s_block_begin,
t = block_t + t_block_begin;
for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x,
t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
this_px = px[b][s][t], this_py = py[b][s][t];
this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
float this_p = LogAdd(p_s1 + this_px,
p_t1 + this_py);
if (i == 0 && is_origin_block)
......@@ -382,7 +403,8 @@ void mutual_information_kernel(
if (threadIdx.x == 0) {
// Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow.
if (s_block_begin + S > s_end && t_block_begin + T > t_end)
if (s_block_begin + block_S - 1 == s_end &&
t_block_begin + block_T - 1 == t_end)
ans[b] = p[b][s_end][t_end];
}
}
......@@ -402,8 +424,8 @@ void mutual_information_kernel(
ep[b][s][t - 1] * epy[b][s][t - 1]. (eq. 1)
(A)
First we consider the part that involves recursion, i.e. the part involving only gradients of
ep. The backprop involving ep only would be:
First we consider the part of the backprop that requires recursion or iteration,
i.e. the part involving only gradients of ep. This is:
ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].
.. and if we add 1 to the s index of the first equation above and 1 to the
......@@ -412,7 +434,8 @@ void mutual_information_kernel(
ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
ep_grad[b][s][t + 1] * epy[b][s][t].
Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
Now, if ep = exp(p), and y is the loss function we are backprop'ing,
then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
== dy/dp / ep. == p_grad / ep.
I.e. ep_grad = p_grad / ep.
So we can write the above as:
......@@ -435,7 +458,7 @@ void mutual_information_kernel(
Using, similar to the above, ep_grad = p_grad / ep, and similarly,
epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
the above becomes
the above becomes:
px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
......@@ -450,11 +473,11 @@ void mutual_information_kernel(
yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]) (eq. 5)
.. and note that these quantities are <= 1 so there is no problem doing
the exponentiation. So the recursion can be simplified as:
the exponentiation. So the recursion, from eqs. (2, 3a, 3b), can be simplified as:
p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 6)
px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t] (eq. 7)
px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] (eq. 7)
py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] (eq. 8)
(It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
......@@ -462,8 +485,9 @@ void mutual_information_kernel(
write to shared memory within the loop that's the limiting factor.)
The backward pass will be slightly different from the forward pass in terms of
how we store p (and p_grad), because for writing a particular block of p_grad, we
need context on the top and right instead of the bottom and left.
how we store and index p (and p_grad), because for writing a particular block
of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1.
*/
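To summarize eqs. (4)-(8), a naive reference implementation of the backward recursion (a sketch only: hypothetical name, no boundary support, vectorized over the batch) might look like this:

import torch

def naive_backward(px, py, p, ans_grad):
    # Reference for eqs. (4)-(8): px (B, S, T+1), py (B, S+1, T),
    # p (B, S+1, T+1) from the forward pass, ans_grad (B,).
    (B, S, T1) = px.shape
    T = T1 - 1
    p_grad = torch.zeros_like(p)
    px_grad = torch.zeros_like(px)
    py_grad = torch.zeros_like(py)
    p_grad[:, S, T] = ans_grad
    for s in range(S, -1, -1):
        for t in range(T, -1, -1):
            if s < S:   # xderiv = exp(p[s,t] + px[s,t] - p[s+1,t])   (eq. 4)
                xd = (p[:, s, t] + px[:, s, t] - p[:, s + 1, t]).exp()
                px_grad[:, s, t] = p_grad[:, s + 1, t] * xd           # (eq. 7)
                p_grad[:, s, t] += p_grad[:, s + 1, t] * xd           # (eq. 6)
            if t < T:   # yderiv = exp(p[s,t] + py[s,t] - p[s,t+1])   (eq. 5)
                yd = (p[:, s, t] + py[:, s, t] - p[:, s, t + 1]).exp()
                py_grad[:, s, t] = p_grad[:, s, t + 1] * yd           # (eq. 8)
                p_grad[:, s, t] += p_grad[:, s, t + 1] * yd           # (eq. 6)
    return px_grad, py_grad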
template <typename scalar_t>
__global__
......@@ -472,8 +496,6 @@ void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S + 1, T + 1. Produced in forward pass.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B]. This is an input.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare, // [B]. A value will be written to here which
// should ideally equal ans_grad.
torch::PackedTensorAccessor32<scalar_t, 3> p_grad, // B, S + 1, T + 1. This is a temporary.
torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
......@@ -483,16 +505,18 @@ void mutual_information_backward_kernel(
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
bool overwrite_ans_grad) { // If true, overwrite ans_grad with a value
// which, if everything is working correctly,
// should be identical or very close to the
// value of ans_grad that was passed in.
bool overwrite_ans_grad) { // If overwite_ans_grad == true, this function
// will overwrite ans_grad with a value which,
// if everything is working correctly, should be
// identical or very close to the value of
// ans_grad that was passed in.
const int B = px.size(0),
S = px.size(1),
T = py.size(2);
// For statements that are the same as the forward pass, we are omitting some comments
// what we made there. We'll focus, in the comments, on differences from the forward pass.
// For statements that are the same as the forward pass, we are omitting some
// comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
num_blocks_this_iter = min(iter + 1, num_s_blocks);
......@@ -502,29 +526,33 @@ void mutual_information_backward_kernel(
// but then modified to store the "xderiv" and "yderiv" values defined
// in (eq. 5) and (eq. 6) above. For out-of-range values, we'll write 0.0
// here.
// Initially (before xderiv/yderiv are written):
// px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
// Later (see eq. 4 and eq. 5):
// px_buf[s][t] contains exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
// py_buf[s][t] contains exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]),
// where ss == s + s_block_begin, tt = t + t_block_begin.
// Unlike in the forward code, there is no offset of 1 in the indexes.
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf is initially used to store p, and then (after we are done putting
// xderiv and yderiv into px_buf and py_buf) it is repurposed to store
// p_grad.
//
// Unlike in the forward pass, p_buf has the same numbering as px_buf and
// py_buf not offset by 1: e.g., for the origin block, p_buf[0][0] refers
// to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than
// py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0]
// refers to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than
// the block for px_buf and py_buf; unlike in the forward pass, we store
// context on the top right, not the bottom left, i.e. the elements at
// context on the top and right, not the bottom and left, i.e. the elements at
// (one past the largest indexes in the block).
//
// For out-of-range elements of p_buf, we'll put zero.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
// boundary_buf will be used to store the b'th row of `boundary` if we have
// boundary information supplied.
// boundary information supplied; or (0, 0, S, T) if not.
__shared__ int64_t boundary_buf[4];
if (threadIdx.x == 0) {
......@@ -541,13 +569,13 @@ void mutual_information_backward_kernel(
for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int b = batch_block_iter % B,
block = batch_block_iter / B;
int s_block_begin = block * BLOCK_S_SIZE,
t_block_begin = (iter - block) * BLOCK_T_SIZE;
int block = batch_block_iter / B,
b = batch_block_iter % B;
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
if (threadDim.x < 4 && boundary.size(0) != 0)
boundary_buf[threadDim.x] = boundary[b][threadDim.x];
if (threadIdx.x < 4 && boundary.size(0) != 0)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0],
......@@ -560,68 +588,69 @@ void mutual_information_backward_kernel(
// block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence.
// The last element of the output matrix p we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
// The last element of the output matrix p_grad we write is (s_end, t_end),
// i.e. the one-past-the-end index of p_grad is (s_end + 1, t_end + 1).
int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
if (block_S <= 0 || block_T <= 0)
continue;
// Load px_buf and py_buf. At this point they just contain px and py
// Load px_buf and py_buf. At this point we just set them to the px and py
// for this block.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// We let ps and py default to -infinity if they are out of range, which will
// We let px and py default to -infinity if they are out of range, which will
// cause xderiv and yderiv for out-of-range values to be zero, and cause
// correct behavior in edge cases (for the top and right blocks).
// The issue is that p and p_grad are of larger size than px and py.
scalar_t this_px = -INFINITY;
if (s < s_end && t <= t_end)
this_px = px[b][s - 1][t];
this_px = px[b][s][t];
px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = -INFINITY;
if (s <= s_end && t < t_end)
this_py = py[b][s][t - 1];
this_py = py[b][s][t];
py_buf[s_in_block][t_in_block] = this_py;
}
// load p. This time we loop over the exact indexes we need. Above
// we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
// because having power-of-2 arrangement of threads may be helpful
// for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1)
// which is not a power of 2, so that is not a concern here.
for (int i = threadDim.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
int s_in_block = i / (BLOCK_SIZE + 1), // 0 <= s_in_block <= block_S
t_in_block = i % (BLOCK_SIZE + 1), // 0 <= t_in_block <= block_T
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// reads more aligned.
for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
int s_in_block = i / (BLOCK_SIZE + 1),
t_in_block = i % (BLOCK_SIZE + 1),
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// Setting 0.0 for out-of-bounds elements, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases,
// i.e. that no derivatives will be propagated from out-of-bounds points.
p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
p[b][s][t] : 0.0);
// i.e. that no derivatives will be propagated from out-of-bounds points
// because the corresponding xderiv and yderiv values will be zero.
scalar_t this_p = 0.0;
if (s <= s_end && t <= t_end)
this_p = p[b][s][t];
p_buf[s_in_block][t_in_block] = this_p;
}
// Set xderiv and yderiv; see (eq. 4) and (eq. 5).
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// We can apply this formula to the entire block even if we are processing
// a partial block; elements outside the partial block will not be used so
// their values don't matter, and elements just out
int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
// a partial block; we have ensured that px_buf and py_buf contain -infinity,
// and p_buf contains 0, for out-of-range elements, so we'll get px_buf and py_buf
// containing 0 after applying the following formulas.
int s = i / BLOCK_SIZE,
t = i % BLOCK_SIZE;
// Mathematically the following is doing:
// xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
// (with an offset on the s and t indexes)
px_buf[s][t] = exp(px_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
// Mathematically the following is doing:
// yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
// (with an offset on the s and t indexes)
py_buf[s][t] = exp(px_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
}
// Load p_grad for the top and right elements in p_buf: i.e. for elements
......@@ -630,7 +659,8 @@ void mutual_information_backward_kernel(
// never be accessed.
// These are the p_grad values computed by previous instances of this kernel
// If this is one of the top or right blocks, some or all of the p_grad
// values we'd be reading here will be out of range, and we use zeros.
// values we'd be reading here will be out of range, and we use zeros
// to ensure no gradient gets propagated from those positions.
if (threadIdx.x < block_S) {
int s_in_block = threadIdx.x,
t_in_block = block_T,
......@@ -638,34 +668,33 @@ void mutual_information_backward_kernel(
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} else if (static_cast<unsigned int>(threadIdx.x - 64) <
} else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) {
// casting to unsigned before the comparison tests for both negative and
// out-of-range values of (int)threadIdx.x - 64.
int s_in_block = block_S,
t_in_block = threadIdx.x - 64,
t_in_block = (int)threadIdx.x - 64,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
}
// The number of inner iterations, i.e. iterations inside this
// kernel, is this_num_inner_iters. The highest iteration,
// corresponding to the highest-indexed value of p_buf that
// we need to set,
// corresponds to p_buf[block_S - 1][block_T - 1],
// and the iteration number is the sum of these indexes, i.e.
// (block_S - 1) + (block_T - 1).
// The highest-numbered value in p_buf that we need (corresponding,
// of course, to p_grad), is:
// p_buf[block_S - 1][block_T - 1],
// and the inner iteration number (i) on which we set this is the sum of
// these indexes, i.e. (block_S - 1) + (block_T - 1).
bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
t_block_begin + block_T == t_end + 1);
int first_iter = block_S + block_T - 2;
if (is_final_block) {
// The following statement, mathematically, corresponds to:
// p_grad[b][s_end][t_end] = ans_grad[b] Normally this element of p_buf
// would be set by the first iteration of the loop below, so if it's set
// this way we have to decrement first_iter to prevent it being
// overwritten.
// The following statement corresponds to:
// p_grad[b][s_end][t_end] = ans_grad[b]
// Normally this element of p_buf would be set by the first iteration of
// the loop below, so if it's set this way we have to decrement first_iter
// to prevent it from being overwritten.
p_buf[block_S - 1][block_T - 1] = ans_grad[b];
--first_iter;
}
......@@ -675,7 +704,8 @@ void mutual_information_backward_kernel(
t = i - threadIdx.x;
if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds to (eq. 6) defined above, i.e.:
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
......@@ -684,17 +714,19 @@ void mutual_information_backward_kernel(
}
// Write out p_grad, px_grad and py_grad.
for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int t_in_block = i % BLOCK_SIZE,
s_in_block = i / BLOCK_SIZE,
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// s_end and t_end are the one-past-the-end of the (x,y) sequences, but
// the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
if (t <= t_end && s <= s_end) {
p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1]
// From (eq. 7):
// px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t]
// px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
px_buf[s_in_block][t_in_block]);
}
......@@ -741,7 +773,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
int num_threads = 128,
num_blocks = 128,
num_blocks = 256,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
......@@ -802,14 +834,18 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK(ans_grad.size(0) == B);
bool has_boundary = (bool)optional_boundary;
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = torch::empty({B, S, T + 1}, opts),
py_grad = torch::empty({B, S + 1, T}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
const int num_threads = 128,
num_blocks = 128,
num_blocks = 256,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
......@@ -819,7 +855,7 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
if ((bool)optional_boundary)
if (has_boundary)
TORCH_CHECK(optional_boundary.value().device().is_cuda(),
"boundary information must be in CUDA tensor");
else
......@@ -838,5 +874,6 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
iter,
overwrite_ans_grad);
}
std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad});
}