Commit 77eed83f authored by Daniel Povey's avatar Daniel Povey
Browse files

Some progress, still drafting.

parent e95d7864
......@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
(B, S, T) = px.shape
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
# the mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively.
# q is a rearrangement of a tensor p which is of shape (B,S,T),
# using p[b,s,t] == q[b,s+t,t]. The reason for working with this
# representation is that each row of q depends only on the previous row,
......@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
if px.requires_grad or py.requires_grad:
q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
else:
# We don't need to store q if we are not going to do backprop, but we
# do pass in a temporary with one real row, expanded to have "fake" rows,
# which happens to be convenient for the CPU implementation.
q = torch.empty({1, 1, T}, device=px.device, dtype=px.dtype).expand(B, S + T, T)
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, q)
......@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T],
px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
where B is the batch size, S is the length of the 'x' sequence
(including representations of EOS symbols but not BOS symbols), and T is the
length of the 'y' sequence (including representations of
......@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat x and y differently; the only difference
is that the implementation assumes for optimization purposes that y
is likely to be the shorter sequence, i.e. that "most of the time T < S",
and it will be faster if you respect this.
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
......@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):
Returns:
Returns a torch.Tensor of shape [B], containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
representing a mutual information between sub-sequences of lengths s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
where in the case where boundaries==None: the edge cases are handled
by treating p[b,-1,-1] as 0 and all other quantities with negative
indexes as -infinity; and ans[b] would equal p[S-1,T-1]. The extension to
cases where the boundaries are specified should be obvious.
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative indexes
as -infinity. The extension to cases where the boundaries are specified
should be obvious; it just works on shorter sequences with offsets into
px and py.
"""
assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
assert px.ndim == 3
B, S, T1 = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
(B, S, T) = px.shape
if boundaries is not None:
assert boundaries.dtype == torch.LongTensor
......
......@@ -3,6 +3,13 @@
// Thin wrapper around exp() so that `Exp` can be overloaded for both
// double and float precision.
inline double Exp(double value) { return exp(value); }
// Single-precision overload of Exp().  Returns float (not double): the whole
// point of the overload is to stay in float precision when the caller's
// scalar type is float, avoiding an implicit widening on every call.
inline float Exp(float x) {
  return expf(x);
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
double diff;
......@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffDouble) {
if (diff >= -1000) {
double res;
res = x + log1p(exp(diff));
return res;
......@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffFloat) {
if (diff >= -200) {
float res;
res = x + log1pf(expf(diff));
return res;
}
return x; // return the larger one.
}
// forward of mutual_information.  See the """...""" comment of
// `mutual_information` in mutual_information.py for documentation of the
// behavior of this function.
//
//  px:  [B][S][T+1] log-odds for extending the x sub-sequence
//  py:  [B][S+1][T] log-odds for extending the y sub-sequence
//  optional_boundary: if present, [B][4] int64 rows of
//       [s_begin, t_begin, s_end, t_end]; if absent, [0, 0, S, T] is assumed.
//  p:   [B][S+1][T+1], an output: the dynamic-programming table, needed
//       later by the backward pass.
// Returns `ans` of shape [B]: ans[b] = p[b][s_end][t_end].
torch::Tensor mutual_information_cpu(torch::Tensor px,
                                     torch::Tensor py,
                                     std::optional<torch::Tensor> optional_boundary,
                                     torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0),
      S = px.size(1),
      T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor ans = torch::empty({B}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary) {
    // A placeholder so the accessor below is valid; it is never read.
    optional_boundary = torch::empty({0, 0}, long_opts);
  }
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().scalar_type() == torch::kInt64);

  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
          py_a = py.packed_accessor32<scalar_t, 3>(),
          p_a = p.packed_accessor32<scalar_t, 3>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
      auto ans_a = ans.packed_accessor32<scalar_t, 1>();

      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        p_a[b][s_begin][t_begin] = 0.0;
        // First column (t == t_begin): can only be reached via the s direction.
        for (int s = s_begin + 1; s <= s_end; ++s)
          p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
        // First row (s == s_begin): can only be reached via the t direction.
        for (int t = t_begin + 1; t <= t_end; ++t)
          p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
        for (int s = s_begin + 1; s <= s_end; ++s) {
          scalar_t p_s_t1 = p_a[b][s][t_begin];
          for (int t = t_begin + 1; t <= t_end; ++t) {
            // The following statement is a small optimization of:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
            // .. which obtains p_a[b][s][t - 1] from a register.
            p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
                                           p_s_t1 + py_a[b][s][t - 1]);
          }
        }
        ans_a[b] = p_a[b][s_end][t_end];
      }
    }));
  return ans;
}
auto input_a = input.accessor<scalar_t, 3>(),
output_a = output.accessor<scalar_t, 3>();
for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) {
scalar_t scale = exp(params_a[c][0]),
inv_scale = 1.0 / scale;
for (int t = 0; t < T; t++) {
// `x` is the scaled input x plus an offset so that -K maps to 0.
// Note: the discontinuities in our function are at -(K-1) ... +(K+1),
// so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1)
scalar_t input = input_a[b][c][t],
x = input * inv_scale + K;
if (x < 0) x = 0;
else if (x >= N) x = N - 1;
// C++ rounds toward zero.
int n = (int) x;
// OK, at this point, 0 <= min < 2*K.
output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
// backward of mutual_information.  Returns (px_grad, py_grad).
// `p` corresponds to what we computed in the forward pass; `ans_grad` is
// the gradient w.r.t. the forward pass's answer, of shape [B].
// px_grad has the shape of px ([B][S][T+1]); py_grad has the shape of
// py ([B][S+1][T]).
std::vector<torch::Tensor> mutual_information_backward_cpu(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> optional_boundary,
    torch::Tensor p,
    torch::Tensor ans_grad) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
              && ans_grad.device().is_cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0),
      S = px.size(1),
      T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  // All-zeros so the backprop below can accumulate with `+=`.
  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
      px_grad = torch::zeros({B, S, T + 1}, opts),
      py_grad = torch::zeros({B, S + 1, T}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary) {
    // A placeholder so the accessor below is valid; it is never read.
    optional_boundary = torch::empty({0, 0}, long_opts);
  }
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().scalar_type() == torch::kInt64);

  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
          py_a = py.packed_accessor32<scalar_t, 3>(),
          p_a = p.packed_accessor32<scalar_t, 3>(),
          p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
          px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
          py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
      auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();

      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        // Backprop for: ans_a[b] = p_a[b][s_end][t_end];
        p_grad_a[b][s_end][t_end] = ans_grad_a[b];
        // Interior cells, in reverse of the forward order.
        for (int s = s_end; s > s_begin; --s) {
          for (int t = t_end; t > t_begin; --t) {
            // The statement we are backpropagating here is:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
            // The derivative of LogAdd w.r.t. each term is the term's
            // softmax weight, exp(term - total).
            scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
                total = p_a[b][s][t],
                term1_deriv = exp(term1 - total),
                term2_deriv = 1.0 - term1_deriv,
                grad = p_grad_a[b][s][t],
                term1_grad = term1_deriv * grad,
                term2_grad = term2_deriv * grad;
            px_grad_a[b][s - 1][t] += term1_grad;
            p_grad_a[b][s - 1][t] += term1_grad;
            py_grad_a[b][s][t - 1] += term2_grad;
            p_grad_a[b][s][t - 1] += term2_grad;
          }
        }
        // First row (s == s_begin).  Note `t > t_begin`: t_begin itself has
        // no predecessor in the t direction.
        for (int t = t_end; t > t_begin; --t) {
          // Backprop for:
          // p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
          scalar_t this_p_grad = p_grad_a[b][s_begin][t];
          p_grad_a[b][s_begin][t - 1] += this_p_grad;
          py_grad_a[b][s_begin][t - 1] += this_p_grad;
        }
        // First column (t == t_begin).
        for (int s = s_end; s > s_begin; --s) {
          // Backprop for:
          // p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
          scalar_t this_p_grad = p_grad_a[b][s][t_begin];
          p_grad_a[b][s - 1][t_begin] += this_p_grad;
          px_grad_a[b][s - 1][t_begin] += this_p_grad;
        }
        // There is no backprop for:
        //   p_a[b][s_begin][t_begin] = 0.0;
        // .. but we can use it for a check: the accumulated grad at the
        // beginning of the sequence should equal the grad at the end.
        if (ans_grad_a[b] != 0.0) {
          float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
          if (fabs(grad_ratio - 1.0) > 0.01) {
            printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                   (float)p_grad_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
          }
        }
      }
    }));
  return std::vector<torch::Tensor>({px_grad, py_grad});
}
// backward of mutual_information. Returns (input_grad, params_grad)
std::vector<torch::Tensor> mutual_information_backward_cpu(torch::Tensor input,
torch::Tensor params,
torch::Tensor output_grad) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
......
......@@ -3,18 +3,27 @@
// forward of mutual_information.  The """...""" comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
// It is the core recursion in the sequence-to-sequence mutual information
// computation.
// Returns `ans`, of dimension B (batch size).
torch::Tensor mutual_information_cuda(torch::Tensor px,  // [B][S][T+1]
                                      torch::Tensor py,  // [B][S+1][T]
                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
                                      torch::Tensor p);  // [B][S+1][T+1]; an output

// backward of mutual_information; returns (grad_px, grad_py)
std::vector<torch::Tensor> mutual_information_backward_cuda(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> boundary_info,
    torch::Tensor p,
    torch::Tensor ans_grad);
// Python bindings: register the forward and backward CUDA entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda", &mutual_information_cuda, "Mutual information forward function (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda, "Mutual information backward function (CUDA)");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment