Commit e95d7864 authored by Daniel Povey's avatar Daniel Povey
Browse files

Drafts..

parent 621d5fbb
......@@ -84,16 +84,19 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
# has t_begin > 0 or s_begin > 0, i.e. we really access q as
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
# We don't need q if we are not going to do backprop
q = (torch.empty(B, S + T, device=px.device, dtype=px.dtype)
if px.requires_grad or py.requires_grad
else None)
if px.requires_grad or py.requires_grad:
q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
else:
# We don't need to store q if we are not going to do backprop, but we
# do pass in a temporary with one real row, expanded to have "fake" rows,
# which happens to be convenient for the CPU implementation.
q = torch.empty({1, 1, T}, device=px.device, dtype=px.dtype).expand(B, S + T, T)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, w)
ctx.save_for_backward(px, py, boundaries, q)
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
......@@ -109,7 +112,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
look at the formula computed by this function.
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T],
......@@ -131,6 +134,11 @@ def mutual_information_recursion(input, px, py, boundaries=None):
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
Note: we don't require px and py to be contiguous, but the
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
......
#include <math.h> // for log1p, log1pf
#include <torch/extension.h>
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffDouble) {
double res;
res = x + log1p(exp(diff));
return res;
}
return x; // return the larger one.
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffFloat) {
float res;
res = x + log1pf(expf(diff));
return res;
}
return x; // return the larger one.
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cpu(torch::Tensor input,
torch::Tensor params) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
torch::Tensor mutual_information_cpu(torch::Tensor px,
torch::Tensor py,
std::optional<torch::Tensor> optional_boundary,
torch::Tensor q) {
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "params must be 3-dimensional.");
TORCH_CHECK(q.dim() == 3, "params must be 3-dimensional.");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2;
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor y_vals = torch::empty({C, N}, opts),
output = torch::empty({B, C, T}, opts);
const int B = px.size(0),
S = px.size(1),
T = px.size(2);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto params_a = params.accessor<scalar_t, 2>(),
y_vals_a = y_vals.accessor<scalar_t, 2>();
TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
auto long_opts = torch::TensorOptiona().dtype(torch::kInt64);
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
optional_boundary = torch::empty({}, long_opts);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>();
for (int c = 0; c < C; c++) {
scalar_t sum_negative = 0.0,
sum_positive = 0.0,
......
......@@ -3,7 +3,6 @@
#include <cooperative_groups.h>
#define THREADS_PER_BLOCK 256
......@@ -43,9 +42,11 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
/*
Forward of mutual_information. Each thread group handles a single channel (channel
c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
image within the batch).
Forward of mutual_information. Each thread block handles blocks of (x, y) shape
equal to (BLOCK_S_SIZE, BLOCK_T_SIZE), e.g. (4, 64). Thread blocks loop over such
blocks, but they might typically loop only once. We sequentially launch groups of
threads in such a way that thread-blocks within a group do not depend on each other.
Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half.
......@@ -88,17 +89,138 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
*/
extern __shared__ int extern_buf[];
template <typename scalar_t>
template <typename scalar_t,
int BLOCK_S_SIZE, // e.g. BLOCK_S_SIZE == 4; power of 2
int BLOCK_T_SIZE> // e.g. BLOCK_T_SIZE == 64; power of 2.
// BLOCK_T_SIZE * 4 must equal num_threads; and must be >= 128, so BLOCK_T_SIZE >= 32 is required.
// (Note: this 4 is unrelated to BLOCK_S_SIZE but can be viewed as 1<<2,
// where 2 is the loop unrolling factor).
__global__
void mutual_information_kernel(
torch::PackedTensorAccessor32<scalar_t, 3> input, // B, C, T, i.e. batch, channels, time
torch::PackedTensorAccessor32<scalar_t, 2> params, // C, N + 1
torch::PackedTensorAccessor32<scalar_t, 3> output,
int images_per_thread_block) { // B, C, T
torch::PackedTensorAccessor32<scalar_t, 3> px, // B, S, T, i.e. batch, x_seq_length, y_seq_length
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S, T, as above
torch::PackedTensorAccessor32<scalar_t, 3> p, // B, S, T, as above. This is an output.
torch::PackedTensorAccessor32<scalar_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
// (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
// so that each group depends on the previous group...
const int block_dimx = BLOCK_T_SIZE * 4; // known at compile time.
assert(blockDim.x == block_dimx);
const int B = px.size(0),
S = px.size(1),
T = py.size(2);
// num_s_blocks and num_t_blocks are the number of blocks we need to cover the
// array of size (S, T) with blocks of this size, in the s and t directions
// respectively.
const int num_s_blocks = (S + BLOCK_S_SIZE - 1) / BLOCK_S_SIZE,
num_t_blocks = (T + BLOCK_T_SIZE - 1) / BLOCK_T_SIZE;
// num_blocks_this_iter is an upper bound on the number of blocks that might
// be active on this iteration. We go from the bottom left of the image
// so that on iter == 0 we process only one block with block-index (0, 0)
// then on iter == 1 we process block-indexes (1, 0) and (0, 1); and then on iter==2
// we process (2, 0), (1, 1) and (0, 2); and so on. We also will never have more
// than `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
// the numbering we use corresponds to s and not t, so if we hit the num_t_blocks limit,
// the lowest-numbered blocks on s would just not be active and we'll 'continue' below).
int num_blocks_this_iter = min(iter + 1, num_s_blocks);
__shared__ scalar_t px_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
py_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
p_buf[BLOCK_S_SIZE + 1][BLOCK_T_SIZE + 1]; // 1st row/col of p_buf
// correspond to the previous
// blocks, or an edge case.
__shared__ boundary_buf[4];
// batch_block_iter iterates over both batch elements (index b), and block
// indexes
for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int b = batch_block_iter % B,
block = batch_block_iter / B;
int s_block_begin = block * BLOCK_S_SIZE,
t_block_begin = (iter - block) * BLOCK_T_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
int s_end, t_end; // s_end and t_end are the end points (last-plus-one) of the entire sequence.
if (boundary.size(0) == 0) {
s_end = S;
t_end = T;
} else {
if (threadDim.x < 4)
boundary_buf[threadDim.x] = boundary[b][threadDim.x];
__syncthreads();
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1];
s_end = boundary_buf[2];
t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
}
// block_S and block_T are the actual sizes of this block, up to
// (BLOCK_S_SIZE, BLOCK_T_SIZE) but possibly truncated if we
// are towards the end of the sequence.
int block_S = min(BLOCK_T_SIZE, s_end - s_block_begin),
block_T = min(BLOCK_S_SIZE, t_end - t_block_begin);
if (block_S <= 0 || block_T <= 0)
continue;
// Load px_buf and py_buf. We exponentiate; the assumption is that they
// won't overflow or underflow! If they overflow we'll detect it later!
for (int i = threadDim.x; i < BLOCK_S_SIZE * BLOCK_T_SIZE; i += block_dimx) {
int t = i % BLOCK_T_SIZE, s = i / BLOCK_T_SIZE;
if (s < block_S && t < block_T) {
px_buf[s][t] = exp(px[b][s + s_block_begin][t + t_block_begin]);
py_buf[s][t] = exp(py[b][s + s_block_begin][t + t_block_begin]);
} else { // Not necessary? We'll see
px_buf[s][t] = 0.0;
py_buf[s][t] = 0.0;
}
}
// Load the 1st row and column of p_buf (except element[0][0] is not needed).
if (threadIdx.x < 64) { // 64 == warp size...
if (threadIdx.x <= BLOCK_S_SIZE) {
// this s and t are offsets relative to the block start
int s = threadIdx.x - 1,
t = -1;
if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
p_buf[threadIdx.x][0] = p[s + s_block_begin][s + t_block_begin];
else
p_buf[threadIdx.x][0] = -infinity;
}
} else {
if (threadIdx.x - 64 <= BLOCK_T_SIZE) {
int i = threadIdx.x - 64,
t = i - 1,
s = -1;
if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
p_buf[0][i] = p[s + s_block_begin][s + t_block_begin];
else {
p_buf[0][i] = (is_origin_block && i == 1 ? 1.0 /
-infinity;
}
}
}
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2; // Note: N and K are powers of 2, with K >= 1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment