Commit 9f929ab3 authored by Daniel Povey

Some cleanup

parent 2b90f668
@@ -2,6 +2,6 @@ include requirements.txt
 include pyproject.toml
 include LICENSE*
 recursive-include torch_mutual_information *
 recursive-include doc/img *
 recursive-include tests *
-global-exclude *.pyc
\ No newline at end of file
+global-exclude *.pyc
@@ -42,12 +42,12 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
 /*
-  Forward of mutual_information. Each thread block handles blocks of (x, y) shape
-  equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32). Thread blocks loop over such
-  blocks, but they might loop only once if there is not that much data to process.
-  We sequentially launch groups of threads in such a way that thread-blocks
-  within a group do not depend on each other.
+  Forward of mutual_information. Each thread block computes blocks of the 'p'
+  array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
+  Thread blocks loop over such blocks, but they might loop only once if there is
+  not that much data to process. We sequentially launch thread groups in
+  such a way that thread-blocks within a group do not depend on each other
+  (see the "iter" parameter).

   Template args:
       scalar_t: the floating-point type, e.g. float, double, maybe half.
@@ -59,9 +59,10 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
          mutual_information.py for more info). Shape [B][S][T + 1]
      py: log-odds ratio of generating next y in the sequence.
          Shape [B][S + 1][T]
-     p: matrix of mutual information of sub-sequences, that this
-        function writes to. Shape [B][S + 1][T + 1]. This function
-        computes the following recursion:
+     p: This function writes to p[s][t] the mutual information between
+        sub-sequences of x and y of length s and t respectively.
+        Its shape is [B][S + 1][T + 1]. This function implements
+        the following recursion:

       p[b,0,0] = 0.0
       p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
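(A minimal CPU reference may help make the kernel's job concrete. This is an illustrative sketch, not code from this repository; it assumes a single sequence with no boundary, and that the second log_add argument, truncated by the hunk above, is the symmetric term p[b,s,t-1] + py[b,s,t-1]:

  #include <cmath>
  #include <utility>
  #include <vector>

  // Stable log(exp(x) + exp(y)), mirroring the kernel's LogAdd.
  float LogAddRef(float x, float y) {
    if (y > x) std::swap(x, y);  // ensure x >= y
    return x + std::log1p(std::exp(y - x));
  }

  // px: shape [S][T + 1]; py: shape [S + 1][T]. Fills p, of shape
  // [S + 1][T + 1], with the mutual information of each pair of
  // prefix lengths (s, t), following the recursion above.
  void MutualInformationRef(const std::vector<std::vector<float>> &px,
                            const std::vector<std::vector<float>> &py,
                            std::vector<std::vector<float>> *p) {
    int S = px.size(), T = py[0].size();
    for (int s = 0; s <= S; ++s) {
      for (int t = 0; t <= T; ++t) {
        if (s == 0 && t == 0) { (*p)[0][0] = 0.0f; continue; }
        float x_term = (s > 0) ? (*p)[s - 1][t] + px[s - 1][t] : -INFINITY,
              y_term = (t > 0) ? (*p)[s][t - 1] + py[s][t - 1] : -INFINITY;
        (*p)[s][t] = LogAddRef(x_term, y_term);
      }
    }
  })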
@@ -745,6 +746,8 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
   torch::Tensor ans = torch::empty({B}, opts);

+  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
+  // (however, num_threads may not be less than 128).
   int num_threads = 128,
       num_blocks = 128,
       BLOCK_SIZE = 32;
@@ -753,8 +756,10 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_iters = num_s_blocks + num_t_blocks - 1;

-  bool has_boundary = (bool)optional_boundary;
-  if (!has_boundary)
+  if ((bool)optional_boundary)
+    TORCH_CHECK(optional_boundary.value().device().is_cuda(),
+                "boundary information must be in CUDA tensor");
+  else
     optional_boundary = torch::empty({0, 0}, long_opts);

   for (int iter = 0; iter < num_iters; ++iter) {
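(The new TORCH_CHECK guards a likely failure mode: a boundary tensor left on the CPU being handed to a CUDA kernel. Since the same check-or-substitute logic now appears in both the forward and backward functions, it could equally live in one small helper; a hedged sketch of that pattern, where CheckedBoundary is a hypothetical name and not part of this codebase:

  #include <torch/extension.h>

  // Hypothetical helper (not in this codebase): a supplied boundary
  // tensor must live on the GPU; if none is supplied, substitute the
  // empty placeholder the kernels treat as "no boundary".
  static torch::Tensor CheckedBoundary(
      torch::optional<torch::Tensor> optional_boundary,
      torch::TensorOptions long_opts) {
    if (optional_boundary) {
      TORCH_CHECK(optional_boundary.value().device().is_cuda(),
                  "boundary information must be in CUDA tensor");
      return optional_boundary.value();
    }
    return torch::empty({0, 0}, long_opts);
  })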
@@ -796,11 +801,14 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+  TORCH_CHECK(ans_grad.size(0) == B);

   torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
                 px_grad = torch::empty({B, S, T + 1}, opts),
                 py_grad = torch::empty({B, S + 1, T}, opts);

+  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
+  // (however, num_threads may not be less than 128).
   const int num_threads = 128,
             num_blocks = 128,
             BLOCK_SIZE = 32;
@@ -809,8 +817,10 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_iters = num_s_blocks + num_t_blocks - 1;

-  bool has_boundary = (bool)optional_boundary;
-  if (!has_boundary)
+  if ((bool)optional_boundary)
+    TORCH_CHECK(optional_boundary.value().device().is_cuda(),
+                "boundary information must be in CUDA tensor");
+  else
     optional_boundary = torch::empty({0, 0}, long_opts);

   for (int iter = num_iters - 1; iter >= 0; --iter) {
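(Note that the backward pass visits the same anti-diagonals in reverse: gradients flow from p[S][T] back toward p[0][0], so block (s_block, t_block) of p_grad depends on blocks (s_block + 1, t_block) and (s_block, t_block + 1). Schematically, on the same illustrative block grid as the forward sketch above:

  // Diagonals are processed in decreasing s_block + t_block, so each
  // block's successor blocks are already done when it runs.
  for (int iter = num_iters - 1; iter >= 0; --iter) {
    // ... launch all blocks with s_block + t_block == iter ...
  })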