Commit 52ae49ee authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix many bugs

parent 3c1ec347
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
inline double Exp(double x) { inline double Exp(double x) {
return exp(x); return exp(x);
} }
...@@ -52,13 +51,15 @@ inline float LogAdd(float x, float y) { ...@@ -52,13 +51,15 @@ inline float LogAdd(float x, float y) {
// forward of mutual_information. See """... """ comment of `mutual_information` in // forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function. // mutual_information.py for documentation of the behavior of this function.
// px: of shape [B, S, T+1] where
torch::Tensor mutual_information_cpu(torch::Tensor px, torch::Tensor mutual_information_cpu(torch::Tensor px,
torch::Tensor py, torch::Tensor py,
std::optional<torch::Tensor> optional_boundary, torch::Tensor boundary,
torch::Tensor p) { torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(), TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
"inputs must be CPU tensors"); "inputs must be CPU tensors");
...@@ -70,26 +71,24 @@ torch::Tensor mutual_information_cpu(torch::Tensor px, ...@@ -70,26 +71,24 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
T = px.size(2) - 1; T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts); torch::Tensor ans = torch::empty({B}, opts);
auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device()); bool has_boundary = (boundary.size(0) != 0);
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts);
TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
optional_boundary.value().dtype == torch::kInt64);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] { AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(), auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(), py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(); p_a = p.packed_accessor32<scalar_t, 3>();
auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>(); auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
auto ans_a = ans.packed_accessor32<scalar_t, 1>(); auto ans_a = ans.packed_accessor32<scalar_t, 1>();
for (int b = 0 b < B; b++) { for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end; int s_begin, s_end, t_begin, t_end;
if (has_boundary) { if (has_boundary) {
s_begin = boundary_a[b][0]; s_begin = boundary_a[b][0];
...@@ -130,16 +129,17 @@ torch::Tensor mutual_information_cpu(torch::Tensor px, ...@@ -130,16 +129,17 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
std::vector<torch::Tensor> mutual_information_backward_cpu( std::vector<torch::Tensor> mutual_information_backward_cpu(
torch::Tensor px, torch::Tensor px,
torch::Tensor py, torch::Tensor py,
std::optional<torch::Tensor> optional_boundary, torch::Tensor boundary,
torch::Tensor p, torch::Tensor p,
torch::Tensor ans_grad) { torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 3-dimensional."); TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu() TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
&& ans_grad.device() == cpu(), && ans_grad.device().is_cpu(),
"inputs must be CPU tensors"); "inputs must be CPU tensors");
auto scalar_t = px.scalar_type(); auto scalar_t = px.scalar_type();
...@@ -150,8 +150,12 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -150,8 +150,12 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
T = px.size(2) - 1; T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
bool has_boundary = (bool)optional_boundary; bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts), torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) : px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
...@@ -159,27 +163,18 @@ std::vector<torch::Tensor> mutual_information_backward_cpu( ...@@ -159,27 +163,18 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) : py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts)); torch::empty({B, S + 1, T}, opts));
auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
if (!has_boundary)
optional_boundary = torch::empty({0, 0}, long_opts);
TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
optional_boundary.value().dtype == torch::kInt64);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] { AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(), auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(), // py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(), p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(), p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(), px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>(); py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>(); auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>(); for (int b = 0; b < B; b++) {
for (int b = 0 b < B; b++) {
int s_begin, s_end, t_begin, t_end; int s_begin, s_end, t_begin, t_end;
if (has_boundary) { if (has_boundary) {
s_begin = boundary_a[b][0]; s_begin = boundary_a[b][0];
......
#include <torch/extension.h> #include <torch/extension.h>
/* /*
Forward of mutual_information. See also """... """ comment of Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. This It is the core recursion `mutual_information` in mutual_information.py. This It is the core recursion
...@@ -33,8 +32,9 @@ ...@@ -33,8 +32,9 @@
contains, where for each batch element b, boundary[b] equals contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end] [s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these x and y sequences that we should process. Alternatively, may be
default to (0, 0, S, T); and they should not exceed these bounds. a tensor of shape [0][0] and type int64_t; the elements will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end], ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified, with s_end and t_end being (S, T) if `boundary` was specified,
...@@ -48,7 +48,7 @@ ...@@ -48,7 +48,7 @@
*/ */
torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1] torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T] torch::Tensor py, // [B][S+1][T]
std::optional<torch::Tensor> boundary_info, // [B][4], int64_t. torch::Tensor boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output torch::Tensor p); // [B][S+1][T+1]; an output
...@@ -63,7 +63,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1] ...@@ -63,7 +63,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
std::vector<torch::Tensor> mutual_information_backward_cuda( std::vector<torch::Tensor> mutual_information_backward_cuda(
torch::Tensor px, torch::Tensor px,
torch::Tensor py, torch::Tensor py,
std::optional<torch::Tensor> boundary_info, torch::Tensor boundary,
torch::Tensor p, torch::Tensor p,
torch::Tensor ans_grad, torch::Tensor ans_grad,
bool overwrite_ans_grad); bool overwrite_ans_grad);
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <cmath> // for INFINITY #include <cmath> // for INFINITY
// returns log(exp(x) + exp(y)). // returns log(exp(x) + exp(y)).
__forceinline__ __device__ double LogAdd(double x, double y) { __forceinline__ __device__ double LogAdd(double x, double y) {
double diff; double diff;
...@@ -22,7 +21,7 @@ __forceinline__ __device__ double LogAdd(double x, double y) { ...@@ -22,7 +21,7 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
} }
// returns log(exp(x) + exp(y)). // returns log(exp(x) + exp(y)).
__forceinline__ __device__ inline float LogAdd(float x, float y) { __forceinline__ __device__ float LogAdd(float x, float y) {
float diff; float diff;
if (x < y) { if (x < y) {
diff = x - y; diff = x - y;
...@@ -81,8 +80,9 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) { ...@@ -81,8 +80,9 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
contains, where for each batch element b, boundary[b] equals contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end] [s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these x and y sequences that we should process. Otherwise, must be
default to (0, 0, S, T); and they should not exceed these bounds. a tensor of shape [0][0] of type int64_t; the values will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end], ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified, with s_end and t_end being (S, T) if `boundary` was specified,
...@@ -118,8 +118,8 @@ void mutual_information_kernel( ...@@ -118,8 +118,8 @@ void mutual_information_kernel(
// You can read the following expressions as simplifications of, for example, // You can read the following expressions as simplifications of, for example,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE, // num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1). // i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
const int num_s_blocks = S / BLOCK_SIZE + 1, const int num_s_blocks = S / BLOCK_SIZE + 1;
num_t_blocks = T / BLOCK_SIZE + 1; //, num_t_blocks = T / BLOCK_SIZE + 1;
// num_blocks_this_iter is an upper bound on the number of blocks of size // num_blocks_this_iter is an upper bound on the number of blocks of size
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`). // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
...@@ -174,7 +174,7 @@ void mutual_information_kernel( ...@@ -174,7 +174,7 @@ void mutual_information_kernel(
bool is_origin_block = (s_block_begin * t_block_begin == 0); bool is_origin_block = (s_block_begin * t_block_begin == 0);
if (boundary.size(0) != 0 && threadIdx.x < 4) if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadDim.x] = boundary[b][threadDim.x]; boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads(); __syncthreads();
int s_begin = boundary_buf[0], int s_begin = boundary_buf[0],
t_begin = boundary_buf[1], t_begin = boundary_buf[1],
...@@ -384,7 +384,7 @@ void mutual_information_kernel( ...@@ -384,7 +384,7 @@ void mutual_information_kernel(
// should help us quite a bit. // should help us quite a bit.
for (int i = 0; i < block_S + block_T - 1; ++i) { for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x, int s_in_block = threadIdx.x,
t_in_block = i - block_s; t_in_block = i - s_in_block;
if (s_in_block < block_S && if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) { static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin, int s = s_in_block + s_block_begin,
...@@ -489,7 +489,8 @@ void mutual_information_kernel( ...@@ -489,7 +489,8 @@ void mutual_information_kernel(
of p_grad, we need context on the top and right instead of the bottom and of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1. left. So there are offsets of 1.
*/ */
template <typename scalar_t> template <typename scalar_t,
int BLOCK_SIZE>
__global__ __global__
void mutual_information_backward_kernel( void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1 torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
...@@ -751,7 +752,7 @@ void mutual_information_backward_kernel( ...@@ -751,7 +752,7 @@ void mutual_information_backward_kernel(
// mutual_information.py for documentation of the behavior of this function. // mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px, torch::Tensor mutual_information_cuda(torch::Tensor px,
torch::Tensor py, torch::Tensor py,
std::optional<torch::Tensor> optional_boundary, torch::Tensor boundary,
torch::Tensor p) { torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
...@@ -767,12 +768,16 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, ...@@ -767,12 +768,16 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
T = px.size(2) - 1; T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cuda() &&
boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts); torch::Tensor ans = torch::empty({B}, opts);
// num_threads and num_blocks and BLOCK_SIZE can be tuned. // num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128). // (however, num_threads may not be less than 128).
int num_threads = 128, const int num_threads = 128,
num_blocks = 256, num_blocks = 256,
BLOCK_SIZE = 32; BLOCK_SIZE = 32;
...@@ -783,21 +788,17 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, ...@@ -783,21 +788,17 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1; num_iters = num_s_blocks + num_t_blocks - 1;
if ((bool)optional_boundary) AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cuda_stub", ([&] {
TORCH_CHECK(optional_boundary.value().device().is_cuda(), for (int iter = 0; iter < num_iters; ++iter) {
"boundary information must be in CUDA tensor"); mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
else px.packed_accessor32<scalar_t, 3>(),
optional_boundary = torch::empty({0, 0}, long_opts); py.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(),
for (int iter = 0; iter < num_iters; ++iter) { boundary.packed_accessor32<int64_t, 2>(),
mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>( ans.packed_accessor32<scalar_t, 1>(),
px.packed_accessor32<scalar_t, 3>(), iter);
py.packed_accessor32<scalar_t, 3>(), }
p.packed_accessor32<scalar_t, 3>(), }));
optional_boundary.value().packed_accessor32<int64_t, 2>(),
ans.packed_accessor32<scalar_t, 1>(),
iter);
}
return ans; return ans;
} }
...@@ -807,12 +808,13 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, ...@@ -807,12 +808,13 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which // If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked // should be identical to the original ans_grad if the computation worked
// as it should. // as it should.
torch::Tensor mutual_information_backward_cuda(torch::Tensor px, std::vector<torch::Tensor>
torch::Tensor py, mutual_information_backward_cuda(torch::Tensor px,
std::optional<torch::Tensor> optional_boundary, torch::Tensor py,
torch::Tensor p, torch::Tensor boundary,
torch::Tensor ans_grad, torch::Tensor p,
bool overwrite_ans_grad) { torch::Tensor ans_grad,
bool overwrite_ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
...@@ -832,9 +834,13 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px, ...@@ -832,9 +834,13 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK(ans_grad.size(0) == b); TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cuda() &&
boundary.dtype() == torch::kInt64);
TORCH_CHECK(ans_grad.size(0) == B);
bool has_boundary = (bool)optional_boundary; bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts), torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) : px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
...@@ -855,25 +861,22 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px, ...@@ -855,25 +861,22 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1; num_iters = num_s_blocks + num_t_blocks - 1;
if (has_boundary)
TORCH_CHECK(optional_boundary.value().device().is_cuda(), AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_backward_stub", ([&] {
"boundary information must be in CUDA tensor"); for (int iter = num_iters - 1; iter >= 0; --iter) {
else mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
optional_boundary = torch::empty({0, 0}, long_opts); px.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(),
for (int iter = num_iters - 1; iter >= 0; --iter) { p.packed_accessor32<scalar_t, 3>(),
mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>( ans_grad.packed_accessor32<scalar_t, 1>(),
px.packed_accessor32<scalar_t, 3>(), p_grad.packed_accessor32<scalar_t, 3>(),
py.packed_accessor32<scalar_t, 3>(), px_grad.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(), py_grad.packed_accessor32<scalar_t, 3>(),
ans_grad.packed_accessor32<scalar_t, 1>, boundary.packed_accessor32<int64_t, 2>(),
p_grad.packed_accessor32<scalar_t, 3>(), iter,
px_grad.packed_accessor32<scalar_t, 3>(), overwrite_ans_grad);
py_grad.packed_accessor32<scalar_t, 3>(), }
optional_boundary.value().packed_accessor32<int64_t, 2>(), }));
iter,
overwrite_ans_grad);
}
std::cout << "p_grad = " << p_grad; std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad}); return std::vector<torch::Tensor>({px_grad, py_grad});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment