Commit 9f929ab3 authored by Daniel Povey
Browse files

Some cleanup

parent 2b90f668
......@@ -2,6 +2,6 @@ include requirements.txt
include pyproject.toml
include LICENSE*
recursive-include torch_mutual_information *
recursive-include doc/img *
precursive-include doc/img *
recursive-include tests *
global-exclude *.pyc
\ No newline at end of file
......@@ -42,12 +42,12 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
/*
Forward of mutual_information. Each thread block handles blocks of (x, y) shape
equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32). Thread blocks loop over such
blocks, but they might loop only once if there is not that much data to process.
We sequentially launch groups of threads in such a way that thread-blocks
within a group do not depend on each other.
Forward of mutual_information. Each thread block computes blocks of the 'p'
array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
Thread blocks loop over such blocks, but they might loop only once if there is
not that much data to process. We sequentially launch thread groups in
such a way that thread-blocks within a group do not depend on each other
(see the "iter" parameter).
Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half.
......@@ -59,9 +59,10 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
mutual_information.py for more info). Shape [B][S][T + 1]
py: log-odds ratio of generating next y in the sequence.
Shape [B][S + 1][T]
p: matrix of mutual information of sub-sequences, that this
function writes to. Shape [B][S + 1][T + 1]. This function
computes the following recursion:
p: This function writes to p[s][t] the mutual information between
sub-sequences of x and y of length s and t respectively.
Its shape is [B][S + 1][T + 1]. This function implements
the following recursion:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
......@@ -745,6 +746,8 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
torch::Tensor ans = torch::empty({B}, opts);
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
int num_threads = 128,
num_blocks = 128,
BLOCK_SIZE = 32;
......@@ -753,8 +756,10 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
if ((bool)optional_boundary)
TORCH_CHECK(optional_boundary.value().device().is_cuda(),
"boundary information must be in CUDA tensor");
else
optional_boundary = torch::empty({0, 0}, long_opts);
for (int iter = 0; iter < num_iters; ++iter) {
......@@ -796,11 +801,14 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK(ans_grad.size(0) == b);
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = torch::empty({B, S, T + 1}, opts),
py_grad = torch::empty({B, S + 1, T}, opts),
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
const int num_threads = 128,
num_blocks = 128,
BLOCK_SIZE = 32;
......@@ -809,8 +817,10 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1;
bool has_boundary = (bool)optional_boundary;
if (!has_boundary)
if ((bool)optional_boundary)
TORCH_CHECK(optional_boundary.value().device().is_cuda(),
"boundary information must be in CUDA tensor");
else
optional_boundary = torch::empty({0, 0}, long_opts);
for (int iter = num_iters - 1; iter >= 0; --iter) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment