Commit 52ae49ee authored by Daniel Povey

Fix many bugs

parent 3c1ec347
@@ -2,7 +1,6 @@
#include <torch/extension.h>
inline double Exp(double x) {
return exp(x);
}
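The LogAdd referenced in the next hunk computes log(exp(x) + exp(y)) without overflow by factoring out the larger argument. For context, a minimal stand-alone sketch of that standard trick (assuming only <cmath>; an illustration, not this file's exact code):

#include <cmath>

// Stable log(exp(x) + exp(y)): pull out max(x, y) so exp() sees only
// a non-positive argument and can never overflow.
inline double LogAddSketch(double x, double y) {
  double diff;
  if (x < y) {
    diff = x - y;  // diff <= 0
    x = y;         // x now holds max(x, y)
  } else {
    diff = y - x;  // diff <= 0
  }
  // exp(diff) underflows harmlessly to 0 when diff is very negative;
  // the real code additionally guards the x == y == -infinity corner.
  return x + log1p(exp(diff));
}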
@@ -52,13 +51,15 @@ inline float LogAdd(float x, float y) {
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
// px: of shape [B, S, T+1] where
torch::Tensor mutual_information_cpu(torch::Tensor px,
torch::Tensor py,
-std::optional<torch::Tensor> optional_boundary,
+torch::Tensor boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
"inputs must be CPU tensors");
@@ -70,26 +71,24 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
+(boundary.size(0) == B && boundary.size(1) == 4));
+TORCH_CHECK(boundary.device().is_cpu() &&
+boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
-auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
-bool has_boundary = (bool)optional_boundary;
-if (!has_boundary)
-optional_boundary = torch::empty({0, 0}, long_opts);
+bool has_boundary = (boundary.size(0) != 0);
-TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
-optional_boundary.value().dtype == torch::kInt64);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>();
-auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
+auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
auto ans_a = ans.packed_accessor32<scalar_t, 1>();
-for (int b = 0 b < B; b++) {
+for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
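The loop body that this hunk elides computes, for each batch element, the dynamic-programming recursion documented in mutual_information.py. In essence, using the accessors above (a hedged sketch of the recursion, not code copied from this commit):

p_a[b][s_begin][t_begin] = 0.0;  // log-prob of the empty prefix pair
for (int s = s_begin; s <= s_end; ++s) {
  for (int t = t_begin; t <= t_end; ++t) {
    if (s == s_begin && t == t_begin) continue;
    // Reach (s, t) from (s-1, t) with weight px, or from (s, t-1)
    // with weight py; combine the two path sets in log-space.
    scalar_t term1 = (s > s_begin ?
                      p_a[b][s - 1][t] + px_a[b][s - 1][t] : -INFINITY),
             term2 = (t > t_begin ?
                      p_a[b][s][t - 1] + py_a[b][s][t - 1] : -INFINITY);
    p_a[b][s][t] = LogAdd(term1, term2);
  }
}
ans_a[b] = p_a[b][s_end][t_end];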
@@ -130,16 +129,17 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
std::vector<torch::Tensor> mutual_information_backward_cpu(
torch::Tensor px,
torch::Tensor py,
-std::optional<torch::Tensor> optional_boundary,
+torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
-&& ans_grad.device() == cpu(),
+&& ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
@@ -150,8 +150,12 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
+(boundary.size(0) == B && boundary.size(1) == 4));
+TORCH_CHECK(boundary.device().is_cpu() &&
+boundary.dtype() == torch::kInt64);
-bool has_boundary = (bool)optional_boundary;
+bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
@@ -159,27 +163,18 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
-auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
-if (!has_boundary)
-optional_boundary = torch::empty({0, 0}, long_opts);
-TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
-optional_boundary.value().dtype == torch::kInt64);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
-py_a = py.packed_accessor32<scalar_t, 3>(),
+// py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
+auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
-auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
-for (int b = 0 b < B; b++) {
+for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
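The elided body here is the reverse sweep of the same recursion: the gradient arriving at each cell is split between its two incoming arcs in proportion to their posterior, exp(arc_score - p[b][s][t]). Since the two posteriors sum to 1, the code can avoid reading py entirely, which is presumably why py_a is commented out above. A hedged sketch consistent with the forward recursion (not this commit's exact code); it relies on p_grad having been zero-initialized:

p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s >= s_begin; --s) {
  for (int t = t_end; t >= t_begin; --t) {
    scalar_t g = p_grad_a[b][s][t];
    if (s > s_begin) {  // arc (s-1, t) -> (s, t), weight px[s-1][t]
      scalar_t w = g * Exp(p_a[b][s - 1][t] + px_a[b][s - 1][t]
                           - p_a[b][s][t]);
      px_grad_a[b][s - 1][t] = w;
      p_grad_a[b][s - 1][t] += w;
    }
    if (t > t_begin) {  // arc (s, t-1) -> (s, t), weight py[s][t-1]
      scalar_t w = g * Exp(p_a[b][s][t - 1] + py_a[b][s][t - 1]
                           - p_a[b][s][t]);
      py_grad_a[b][s][t - 1] = w;
      p_grad_a[b][s][t - 1] += w;
    }
  }
}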
......
#include <torch/extension.h>
/*
Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. It is the core recursion
@@ -33,8 +32,9 @@
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
-x and y sequences that we should process. If not set, these
-default to (0, 0, S, T); and they should not exceed these bounds.
+x and y sequences that we should process. Alternatively, may be
+a tensor of shape [0][0] and type int64_t; the elements will
+default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
@@ -48,7 +48,7 @@
*/
torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
-std::optional<torch::Tensor> boundary_info, // [B][4], int64_t.
+torch::Tensor boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
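To make the new calling convention concrete, the caller passes either per-sequence limits (one [s_begin, t_begin, s_end, t_end] row per batch element) or the empty placeholder. A hypothetical example, with names and values invented for illustration:

// Explicit limits for B = 2; each row must satisfy
// 0 <= s_begin <= s_end <= S and 0 <= t_begin <= t_end <= T.
torch::Tensor boundary = torch::tensor({{0, 0, 10, 20},
                                        {2, 3,  8, 15}},
    torch::dtype(torch::kInt64).device(torch::kCUDA));

// "Unset": the [0][0] int64 placeholder this commit standardizes on;
// every sequence then defaults to (0, 0, S, T).
torch::Tensor no_boundary = torch::empty({0, 0},
    torch::dtype(torch::kInt64).device(torch::kCUDA));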
@@ -63,7 +63,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
std::vector<torch::Tensor> mutual_information_backward_cuda(
torch::Tensor px,
torch::Tensor py,
-std::optional<torch::Tensor> boundary_info,
+torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad,
bool overwrite_ans_grad);
......
@@ -4,7 +4,6 @@
#include <cmath> // for INFINITY
// returns log(exp(x) + exp(y)).
__forceinline__ __device__ double LogAdd(double x, double y) {
double diff;
@@ -22,7 +21,7 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
}
// returns log(exp(x) + exp(y)).
-__forceinline__ __device__ inline float LogAdd(float x, float y) {
+__forceinline__ __device__ float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
@@ -81,8 +80,9 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
-x and y sequences that we should process. If not set, these
-default to (0, 0, S, T); and they should not exceed these bounds.
+x and y sequences that we should process. Otherwise, must be
+a tensor of shape [0][0] of type int64_t; the values will
+default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
@@ -118,8 +118,8 @@ void mutual_information_kernel(
// You can read the following expressions as simplifications of, for example,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
-const int num_s_blocks = S / BLOCK_SIZE + 1,
-num_t_blocks = T / BLOCK_SIZE + 1;
+const int num_s_blocks = S / BLOCK_SIZE + 1;
+//, num_t_blocks = T / BLOCK_SIZE + 1;
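For example, with S = 100 and BLOCK_SIZE = 32, ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE = 132 / 32 = 4 and S / BLOCK_SIZE + 1 = 3 + 1 = 4. The two forms agree for every S >= 0: writing S = q * BLOCK_SIZE + r with 0 <= r < BLOCK_SIZE, both evaluate to q + 1.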
// num_blocks_this_iter is an upper bound on the number of blocks of size
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
@@ -174,7 +174,7 @@ void mutual_information_kernel(
bool is_origin_block = (s_block_begin * t_block_begin == 0);
if (boundary.size(0) != 0 && threadIdx.x < 4)
-boundary_buf[threadDim.x] = boundary[b][threadDim.x];
+boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1],
@@ -384,7 +384,7 @@ void mutual_information_kernel(
// should help us quite a bit.
for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x,
-t_in_block = i - block_s;
+t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
@@ -489,7 +489,8 @@ void mutual_information_kernel(
of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1.
*/
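Both kernels sweep the (s, t) grid along anti-diagonals, which is what the t_in_block = i - s_in_block fix above restores: on sweep i, thread s handles cell (s, i - s), so each cell's left and lower neighbours were completed on sweep i - 1. A toy, self-contained CUDA illustration of just this pattern (assumed names, launched with a single block of at least S threads; not this repo's kernel):

__global__ void antidiagonal_sweep(float *p, int S, int T) {
  for (int i = 0; i < S + T - 1; ++i) {  // one anti-diagonal per sweep
    int s = threadIdx.x, t = i - s;
    if (s < S &&
        static_cast<unsigned int>(t) < static_cast<unsigned int>(T)) {
      // Reads touch only cells on diagonal i - 1, finished last sweep.
      float left  = (t > 0 ? p[s * T + (t - 1)] : 0.0f),
            below = (s > 0 ? p[(s - 1) * T + t] : 0.0f);
      p[s * T + t] = left + below;  // stand-in for the LogAdd recursion
    }
    __syncthreads();  // finish diagonal i before any thread starts i + 1
  }
}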
-template <typename scalar_t>
+template <typename scalar_t,
+int BLOCK_SIZE>
__global__
void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
@@ -751,7 +752,7 @@ void mutual_information_backward_kernel(
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px,
torch::Tensor py,
-std::optional<torch::Tensor> optional_boundary,
+torch::Tensor boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
@@ -767,12 +768,16 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
+(boundary.size(0) == B && boundary.size(1) == 4));
+TORCH_CHECK(boundary.device().is_cuda() &&
+boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
-int num_threads = 128,
+const int num_threads = 128,
num_blocks = 256,
BLOCK_SIZE = 32;
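Making these const matters beyond style: BLOCK_SIZE is passed as a template argument below (mutual_information_kernel<scalar_t, BLOCK_SIZE>), and a const int initialized with a literal is a constant expression usable there, while a plain int is not.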
@@ -783,21 +788,17 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
-if ((bool)optional_boundary)
-TORCH_CHECK(optional_boundary.value().device().is_cuda(),
-"boundary information must be in CUDA tensor");
-else
-optional_boundary = torch::empty({0, 0}, long_opts);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cuda_stub", ([&] {
for (int iter = 0; iter < num_iters; ++iter) {
mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
px.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(),
-optional_boundary.value().packed_accessor32<int64_t, 2>(),
+boundary.packed_accessor32<int64_t, 2>(),
ans.packed_accessor32<scalar_t, 1>(),
iter);
}
}));
return ans;
}
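A hypothetical end-to-end call of the forward pass, with tensor names and sizes invented for illustration:

int B = 2, S = 10, T = 20;
auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
torch::Tensor px = torch::randn({B, S, T + 1}, opts),      // [B][S][T+1]
              py = torch::randn({B, S + 1, T}, opts),      // [B][S+1][T]
              p  = torch::empty({B, S + 1, T + 1}, opts);  // output buffer
torch::Tensor boundary = torch::empty({0, 0},              // "unset" form
    torch::dtype(torch::kInt64).device(torch::kCUDA));
torch::Tensor ans = mutual_information_cuda(px, py, boundary, p);  // [B]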
@@ -807,9 +808,10 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked
// as it should.
-torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
+std::vector<torch::Tensor>
+mutual_information_backward_cuda(torch::Tensor px,
torch::Tensor py,
-std::optional<torch::Tensor> optional_boundary,
+torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad,
bool overwrite_ans_grad) {
@@ -832,9 +834,13 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
-TORCH_CHECK(ans_grad.size(0) == b);
+TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
+(boundary.size(0) == B && boundary.size(1) == 4));
+TORCH_CHECK(boundary.device().is_cuda() &&
+boundary.dtype() == torch::kInt64);
+TORCH_CHECK(ans_grad.size(0) == B);
-bool has_boundary = (bool)optional_boundary;
+bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
@@ -855,25 +861,22 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
-if (has_boundary)
-TORCH_CHECK(optional_boundary.value().device().is_cuda(),
-"boundary information must be in CUDA tensor");
-else
-optional_boundary = torch::empty({0, 0}, long_opts);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_backward_stub", ([&] {
for (int iter = num_iters - 1; iter >= 0; --iter) {
mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
px.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(),
-ans_grad.packed_accessor32<scalar_t, 1>,
+ans_grad.packed_accessor32<scalar_t, 1>(),
p_grad.packed_accessor32<scalar_t, 3>(),
px_grad.packed_accessor32<scalar_t, 3>(),
py_grad.packed_accessor32<scalar_t, 3>(),
-optional_boundary.value().packed_accessor32<int64_t, 2>(),
+boundary.packed_accessor32<int64_t, 2>(),
iter,
overwrite_ans_grad);
}
}));
std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad});
}
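And the matching backward call (same invented tensors as above; overwrite_ans_grad = true asks the kernel to reconstruct ans_grad as a consistency check, per the comment above):

torch::Tensor ans_grad = torch::ones({B}, opts);  // d(loss)/d(ans)
std::vector<torch::Tensor> grads = mutual_information_backward_cuda(
    px, py, boundary, p, ans_grad, /*overwrite_ans_grad=*/true);
torch::Tensor px_grad = grads[0], py_grad = grads[1];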