Commit 3c1ec347 authored by Daniel Povey

Get it to a stage where it looks like it might compile

parent 8ed6deff
@@ -44,66 +44,72 @@ except ImportError:

 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                           boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+                                           boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError('Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, q)
+            px, py, boundaries, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, q)
+            px, py, boundaries, p)


 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                            boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                                            boundaries: torch.Tensor, p: torch.Tensor,
+                                            ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError('Failed to load native CUDA module')
-        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, q))
+        overwrite_ans_grad = True
+        if overwrite_ans_grad:
+            ans_grad_copy = ans_grad.clone()
+        ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
+            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+        if overwrite_ans_grad:
+            if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
+                print(f"Warning: possible excessive roundoff in mutual information backward "
+                      f"recursion: {ans_grad} vs. {ans_grad_copy}")
+        return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, q))
+            px, py, boundaries, p, ans_grad))
 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
-        (B, S, T) = px.shape
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+        (B, S, T1) = px.shape
+        T = T1 - 1
         assert py.shape == (B, S + 1, T)
         if boundaries is not None:
             assert boundaries.shape == (B, 4)
-        # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is related to
+        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is
         # the mutual information of the pair of subsequences of x and y that are of
         # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
         # the mutual information of the entire pair of sequences, i.e. of lengths
         # S and T respectively.
-        # q is a rearrangement of a tensor p which is of shape (B,S,T),
-        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
-        # representation is that each row of q depends only on the previous row,
-        # so we can access the rows sequentially and this leads to
-        # better memory access patterns.  We are assuming that most likely
-        # T < S, which means that q should not require much more memory than p.
-        #
-        # Actually we access q beginning from 0 indexes even if `boundaries`
-        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
-        # q[b, s-s_begin + t-t_begin, t-t_begin];
-        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
         # It is computed as follows (in C++ and CUDA):
         #     p[b,0,0] = 0.0
         #     p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
         #                        p[b,s,t-1] + py[b,s,t-1])  if s > 0 or t > 0,
         # treating values with any -1 index as -infinity.
         # .. if `boundaries` is set, we start from p[b,s_begin,t_begin]=0.0.
         p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
         ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, q)
+            ctx.save_for_backward(px, py, boundaries, p)
+        return ans

     @staticmethod
     def backward(ctx, ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, q) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
+        (px, py, boundaries, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
         return (px_grad, py_grad, None)
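
Note: the recursion in the comment block above can be cross-checked against the native kernels with a naive pure-PyTorch version. The sketch below is an editorial aid, not part of the commit (the name `mutual_information_ref` is hypothetical); it implements exactly the log-add recursion described, without `boundaries` support, and its O(S*T) Python loop is only practical for small inputs:

```python
import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # Naive reference for the forward recursion: p[b,0,0] = 0, and
    # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]),
    # treating any -1 index as -infinity.
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = float('-inf')
    p = torch.full((B, S + 1, T + 1), neg_inf, dtype=px.dtype)
    p[:, 0, 0] = 0.0
    neg_inf_col = torch.full((B,), neg_inf, dtype=px.dtype)
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term1 = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf_col
            term2 = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf_col
            p[:, s, t] = torch.logaddexp(term1, term2)
    return p[:, S, T]  # mutual information of the full pair of sequences
```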
......
@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

-  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+  bool has_boundary = (bool)optional_boundary;
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
+      px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S, T + 1}, opts)),
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
+                 torch::empty({B, S + 1, T}, opts));

   auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
-  bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);
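
Note on the zeros-vs-empty split above: when `optional_boundary` is supplied, the backward loops only touch cells inside each sequence's [s_begin, t_begin, s_end, t_end] rectangle, so gradient entries outside it must start at zero; without boundaries every entry is written, and torch::empty avoids a redundant fill. An illustrative Python analogue (editor's sketch; the region arithmetic is schematic, not the kernel's exact indexing):

```python
import torch

B, S, T = 1, 4, 5
has_boundary = True
boundary = torch.tensor([[1, 2, 3, 4]])  # [s_begin, t_begin, s_end, t_end]

# Zero-fill only when some cells may be left untouched by the loops.
px_grad = torch.zeros(B, S, T + 1) if has_boundary else torch.empty(B, S, T + 1)

s_begin, t_begin, s_end, t_end = boundary[0].tolist()
# The backward recursion writes each in-region cell once; everything
# outside the rectangle keeps its zero initialization.
px_grad[0, s_begin:s_end, t_begin:t_end + 1] = 1.0
```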
@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       auto px_a = px.packed_accessor32<scalar_t, 3>(),
           py_a = py.packed_accessor32<scalar_t, 3>(),
           p_a = p.packed_accessor32<scalar_t, 3>(),
-          p_grad_a = p.packed_accessor32<scalar_t, 3>();
+          p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
+          px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
+          py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
       auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
           // .. which obtains p_a[b][s][t - 1] from a register.
           scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
               // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1],  <-- not
               //   actually needed..
               total = p_a[b][s][t],
               term1_deriv = exp(term1 - total),
               term2_deriv = 1.0 - term1_deriv,
               grad = p_grad_a[b][s][t],
               term1_grad = term1_deriv * grad,
               term2_grad = term2_deriv * grad;
           // We can assign to px_grad_a here rather than add, because we
           // know it's currently zero.
-          TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
           px_grad_a[b][s - 1][t] = term1_grad;
-          TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);  // likewise..
-          p_grad_a[b][s - 1][t] = term1_grad
-          py_grad_a[b][s][t - 1] += term2_grad;
+          p_grad_a[b][s - 1][t] = term1_grad;
+          py_grad_a[b][s][t - 1] = term2_grad;
+          p_grad_a[b][s][t - 1] += term2_grad;
         }
       }
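
Note: the term1_deriv / term2_deriv lines implement the standard gradient of log-add: if total = log(exp(term1) + exp(term2)), then d(total)/d(term1) = exp(term1 - total) and d(total)/d(term2) = 1 - exp(term1 - total). A quick autograd cross-check of those closed forms (editor's sketch, not part of the commit):

```python
import torch

term1 = torch.tensor(0.3, requires_grad=True)
term2 = torch.tensor(-1.2, requires_grad=True)
total = torch.logaddexp(term1, term2)
total.backward()

# Same closed forms as term1_deriv / term2_deriv in the C++ loop above.
term1_deriv = torch.exp(term1 - total).detach()
assert torch.isclose(term1.grad, term1_deriv)
assert torch.isclose(term2.grad, 1.0 - term1_deriv)
```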
@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
         }
       }
     }));
-  return ans;
-}
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(input.sizes() == output_grad.sizes(),
-              "Output-grad vs. input sizes mismatch.");
-  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
-  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1,
-      K = N / 2;
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-  torch::Tensor y_vals = torch::empty({C, N}, opts),
-      y_vals_grad = torch::zeros({C, N}, opts),
-      params_grad = torch::zeros({C, N + 1}, opts),
-      input_grad = torch::zeros({B, C, T}, opts);
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
-      auto params_a = params.accessor<scalar_t, 2>(),
-          params_grad_a = params_grad.accessor<scalar_t, 2>(),
-          y_vals_a = y_vals.accessor<scalar_t, 2>(),
-          y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
-      for (int c = 0; c < C; c++) {
-        scalar_t sum_negative = 0.0,
-            sum_positive = 0.0,
-            scale = exp(params_a[c][0]);
-        for (int i = 0; i < K; i++) {
-          scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-              neg_scaled_param = params_a[c][K - i] * scale;
-          y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          sum_positive += pos_scaled_param;
-          sum_negative -= neg_scaled_param;
-          y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-        }
-      }
-      auto input_a = input.accessor<scalar_t, 3>(),
-          output_grad_a = output_grad.accessor<scalar_t, 3>(),
-          input_grad_a = input_grad.accessor<scalar_t, 3>();
-      for (int b = 0; b < B; b++) {
-        for (int c = 0; c < C; c++) {
-          scalar_t inv_scale = exp(-params_a[c][0]);
-          for (int t = 0; t < T; t++) {
-            scalar_t input = input_a[b][c][t],
-                x = input * inv_scale + K,
-                output_grad = output_grad_a[b][c][t];
-            if (x < 0) x = 0;
-            else if (x >= N) x = N - 1;
-            // C++ rounds toward zero.
-            int n = (int) x;
-            // OK, at this point, 0 <= n < 2*K.
-            // backprop for:
-            // output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-            params_grad_a[c][n + 1] += output_grad * input;
-            y_vals_grad_a[c][n] += output_grad;
-            input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
-          }
-        }
-      }
-      // Now do the backprop for the loop above where we set y_vals_a.
-      for (int c = 0; c < C; c++) {
-        scalar_t scale = exp(params_a[c][0]),
-            scale_grad = 0.0,
-            sum_negative_grad = 0.0,
-            sum_positive_grad = 0.0;
-        for (int i = K - 1; i >= 0; i--) {
-          // Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-          scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
-          sum_negative_grad += y_grad_neg;
-          scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-          // Backprop for: sum_negative -= neg_scaled_param;
-          neg_scaled_param_grad -= sum_negative_grad;
-          // Backprop for: sum_positive += pos_scaled_param;
-          scalar_t pos_scaled_param_grad = sum_positive_grad;
-          // Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
-          pos_scaled_param_grad -= i * y_grad_pos;
-          sum_positive_grad += y_grad_pos;
-          // Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
-          // and: neg_scaled_param = params_a[c][K - i] * scale;
-          params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
-          params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
-          scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
-                         neg_scaled_param_grad * params_a[c][K - i]);
-        }
-        // Backprop for: scale = exp(params_a[c][0]),
-        params_grad_a[c][0] += scale * scale_grad;
-      }
-    }));
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+  std::cout << "p_grad = " << p_grad;
+  return std::vector<torch::Tensor>({px_grad, py_grad});
 }
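
Note: once both extensions build, the autograd function from the Python file above would be exercised roughly like this (editor's sketch; shapes follow the asserts in forward(), and boundaries=None means the full region):

```python
import torch

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1, requires_grad=True)   # (B, S, T+1)
py = torch.randn(B, S + 1, T, requires_grad=True)   # (B, S+1, T)

mi = MutualInformationRecursionFunction.apply(px, py, None)  # shape (B,)
mi.sum().backward()
print(px.grad.shape, py.grad.shape)
```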
......