Commit 77eed83f authored by Daniel Povey

Some progress, still drafting.

parent e95d7864
...@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
        (B, S, T1) = px.shape
        T = T1 - 1
        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
        # the mutual information of the pair of subsequences of x and y that are of
        # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
        # the mutual information of the entire pair of sequences, i.e. of lengths
        # S and T respectively.
        # q is a rearrangement of the tensor p, which is of shape (B,S,T),
        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
        # representation is that each row of q depends only on the previous row,
...@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
        #   q[b, s-s_begin + t-t_begin, t-t_begin];
        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
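        # For example (an illustration, assuming s_begin = t_begin = 0): the cells
        # of p with s + t == 3, i.e. p[b,0,3], p[b,1,2], p[b,2,1], p[b,3,0], all map
        # to row 3 of q, as q[b,3,3], q[b,3,2], q[b,3,1], q[b,3,0].  Since p[b,s,t]
        # depends only on p[b,s-1,t] and p[b,s,t-1], which both lie on the previous
        # anti-diagonal (row 2 of q), each row of q can be computed from the previous
        # row alone, which is convenient for computing one row at a time (e.g. in
        # parallel on the GPU).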
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)

        if px.requires_grad or py.requires_grad:
            ctx.save_for_backward(px, py, boundaries, p)
...@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
    make use of the formula computed by this function.

    Args:
       px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
           where B is the batch size, S is the length of the 'x' sequence
           (including representations of EOS symbols but not BOS symbols), and T is the
           length of the 'y' sequence (including representations of
...@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):
           code assumes for optimization purposes that the T axis has
           stride 1.
       py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
           representing
             py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
           This function does not treat x and y differently; the only difference
           is that for optimization purposes we assume the last axis (the t axis)
           has stride of 1; this is true if px and py are contiguous.
       boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
           [s_begin, t_begin, s_end, t_end].  If not supplied, the values
           [0, 0, S, T] will be assumed.  These are the beginning and
...@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):
    Returns:
        Returns a torch.Tensor of shape [B], containing the log of the mutual
        information between the b'th pair of sequences.  This is defined by
        the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
        representing the mutual information between sub-sequences of lengths s and t:

             p[b,0,0] = 0.0
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
                        (if s > 0 or t > 0)

        where we handle edge cases by treating quantities with negative indexes
        as -infinity; the returned value ans[b] equals p[b,S,T].  The extension to
        cases where the boundaries are specified should be obvious; it just works
        on shorter sequences with offsets into px and py.
    """
    assert px.ndim == 3
    B, S, T1 = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    assert px.dtype == py.dtype
    if boundaries is not None:
        assert boundaries.dtype == torch.int64
......
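For reference, the recursion documented in the docstring above can be written naively in Python as follows. This is an illustrative sketch only, not code from this commit; it assumes `boundaries` is None and uses torch.logaddexp for log_add.

import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: (B, S, T+1), py: (B, S+1, T), as documented above.
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T) and px.dtype == py.dtype
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype, device=px.device)
    p = torch.full((B, S + 1, T + 1), float('-inf'), dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term1 = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term2 = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term1, term2)
    return p[:, S, T]  # ans[b] == p[b, S, T]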
...@@ -3,6 +3,13 @@

inline double Exp(double x) {
  return exp(x);
}

inline float Exp(float x) {
  return expf(x);
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
  double diff;
...@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffDouble) {
    double res;
    res = x + log1p(exp(diff));
    return res;
...@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffFloat) {
    float res;
    res = x + log1pf(expf(diff));
    return res;
  }
  return x;  // return the larger one.
}
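// Worked example (illustrative): LogAdd(-1000.0f, -1000.0f) should be
// -1000 + log(2), about -999.31.  A naive logf(expf(x) + expf(y)) would
// underflow both expf() calls to 0 and return -inf; the form
// x + log1pf(expf(diff)) only exponentiates the (non-positive) difference,
// so it stays accurate.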
// forward of mutual_information.  See the """ ... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cpu(torch::Tensor px,
                                     torch::Tensor py,
                                     std::optional<torch::Tensor> optional_boundary,
                                     torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0),
      S = px.size(1),
      T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor ans = torch::empty({B}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().scalar_type() == torch::kInt64);
  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
          py_a = py.packed_accessor32<scalar_t, 3>(),
          p_a = p.packed_accessor32<scalar_t, 3>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
      auto ans_a = ans.packed_accessor32<scalar_t, 1>();

      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        p_a[b][s_begin][t_begin] = 0.0;
        for (int s = s_begin + 1; s <= s_end; ++s)
          p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
        for (int t = t_begin + 1; t <= t_end; ++t)
          p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
        for (int s = s_begin + 1; s <= s_end; ++s) {
          scalar_t p_s_t1 = p_a[b][s][t_begin];
          for (int t = t_begin + 1; t <= t_end; ++t) {
            // The following statement is a small optimization of:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
            // .. which obtains p_a[b][s][t - 1] from a register.
            p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
                                           p_s_t1 + py_a[b][s][t - 1]);
          }
        }
        ans_a[b] = p_a[b][s_end][t_end];
      }
    }));
  return ans;
}

// backward of mutual_information.  Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> mutual_information_backward_cpu(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> optional_boundary,
    torch::Tensor p,
    torch::Tensor ans_grad) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");

  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
              && ans_grad.device().is_cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0),
      S = px.size(1),
      T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
      px_grad = torch::zeros({B, S, T + 1}, opts),
      py_grad = torch::zeros({B, S + 1, T}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().scalar_type() == torch::kInt64);

  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
          py_a = py.packed_accessor32<scalar_t, 3>(),
          p_a = p.packed_accessor32<scalar_t, 3>(),
          p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
          px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
          py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
      auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        // Backprop for:  ans_a[b] = p_a[b][s_end][t_end];
        p_grad_a[b][s_end][t_end] = ans_grad_a[b];
        for (int s = s_end; s > s_begin; --s) {
          for (int t = t_end; t > t_begin; --t) {
            // The statement we are backpropagating here is:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
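            // Since total == LogAdd(term1, term2) == log(exp(term1) + exp(term2)),
            //   d(total)/d(term1) = exp(term1) / (exp(term1) + exp(term2))
            //                     = exp(term1 - total),
            // and d(total)/d(term2) = 1 - d(total)/d(term1).  This is what
            // term1_deriv and term2_deriv below compute.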
            scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
                total = p_a[b][s][t],
                term1_deriv = exp(term1 - total),
                term2_deriv = 1.0 - term1_deriv,
                grad = p_grad_a[b][s][t],
                term1_grad = term1_deriv * grad,
                term2_grad = term2_deriv * grad;
            // We can assign to px_grad_a here rather than add, because we
            // know it's currently zero.
            TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
            px_grad_a[b][s - 1][t] = term1_grad;
            TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);  // likewise..
            p_grad_a[b][s - 1][t] = term1_grad;
            py_grad_a[b][s][t - 1] += term2_grad;
            p_grad_a[b][s][t - 1] += term2_grad;
          }
        }
        for (int t = t_end; t > t_begin; --t) {
          // Backprop for:
          //   p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
          scalar_t this_p_grad = p_grad_a[b][s_begin][t];
          p_grad_a[b][s_begin][t - 1] += this_p_grad;
          py_grad_a[b][s_begin][t - 1] += this_p_grad;
        }
        for (int s = s_end; s > s_begin; --s) {
          // Backprop for:
          //   p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
          scalar_t this_p_grad = p_grad_a[b][s][t_begin];
          p_grad_a[b][s - 1][t_begin] += this_p_grad;
          px_grad_a[b][s - 1][t_begin] += this_p_grad;
        }
        // There is no backprop for:
        //   p_a[b][s_begin][t_begin] = 0.0;
        // .. but we can use this for a check, that the grad at the beginning
        // of the sequence is equal to the grad at the end of the sequence.
        if (ans_grad_a[b] != 0.0) {
          float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
          if (fabs(grad_ratio - 1.0) > 0.01) {
            printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                   (float)p_grad_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
          }
        }
      }
    }));
  return std::vector<torch::Tensor>({px_grad, py_grad});
}

TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
......
...@@ -3,18 +3,27 @@

// forward of mutual_information.  The """ ... """ comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
// It is the core recursion in the sequence-to-sequence mutual information
// computation.
// returns 'ans', of dimension B (batch size).
torch::Tensor mutual_information_cuda(torch::Tensor px,    // [B][S][T+1]
                                      torch::Tensor py,    // [B][S+1][T]
                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
                                      torch::Tensor p);    // [B][S+1][T+1]; an output
// backward of mutual_information; returns (grad_px, grad_py).
std::vector<torch::Tensor> mutual_information_backward_cuda(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> boundary_info,
    torch::Tensor p,
    torch::Tensor ans_grad);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda", &mutual_information_cuda, "Mutual information forward function (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda, "Mutual information backward function (CUDA)");
}