GitLab: OpenDAS / FAST-RNNT, commit 2b90f668
Authored Jul 28, 2021 by Daniel Povey
Parent: 77eed83f

Commit message: "Pretty close to finishing all the core code, but need to check through it."

Showing 1 changed file with 245 additions and 436 deletions:
torch_mutual_information/mutual_information_cuda_kernel.cu (+245, -436)
@@ -5,8 +5,6 @@
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__ double LogAdd(double x, double y) {
   double diff;
@@ -27,7 +25,6 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
 // returns log(exp(x) + exp(y)).
 __forceinline__ __device__
-inline
 float LogAdd(float x, float y) {
   float diff;
   if (x < y) {
     diff = x - y;
     x = y;
@@ -118,9 +115,12 @@ void mutual_information_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> p,       // B, S + 1, T + 1.  This is an output.
     torch::PackedTensorAccessor32<int64_t, 2> boundary, // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
     torch::PackedTensorAccessor32<scalar_t, 1> ans,     // [B]
-    int iter) {
-  // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
-  //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
-  // so that each group depends on the previous group...
+    int iter) {
+  // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on,
+  // up to num_iters - 1, where
+  //   num_iters = num_s_blocks + num_t_blocks - 1,
+  //   num_s_blocks = S / BLOCK_SIZE + 1,
+  //   num_t_blocks = T / BLOCK_SIZE + 1,
+  // so that each group depends on the previous group...
   const int B = px.size(0), S = px.size(1), T = py.size(2);
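As a concrete check of these formulas (illustrative numbers, not from the commit): with S = 100, T = 200 and BLOCK_SIZE = 32, integer division gives num_s_blocks = 100 / 32 + 1 = 4 and num_t_blocks = 200 / 32 + 1 = 7, so the kernel is launched num_iters = 4 + 7 - 1 = 10 times, one launch per anti-diagonal of blocks.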
@@ -180,35 +180,36 @@ void mutual_information_kernel(
     int s_block_begin = block * BLOCK_S_SIZE,
         t_block_begin = (iter - block) * BLOCK_T_SIZE;
-    bool is_origin_block = (s_block_begin * t_block_begin == 0);
-    int s_end, t_end;  // s_end and t_end are the end points (last-plus-one)
-                       // of the entire sequence.
     if (threadIdx.x < 4 && boundary.size(0) != 0)
       boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
     __syncthreads();
-    int s_begin = boundary_buf[0], t_begin = boundary_buf[1];
-    s_end = boundary_buf[2];
-    t_end = boundary_buf[3];
+    int s_begin = boundary_buf[0],
+        t_begin = boundary_buf[1],
+        s_end = boundary_buf[2],
+        t_end = boundary_buf[3];
     s_block_begin += s_begin;
     t_block_begin += t_begin;
     // block_S and block_T are the actual sizes of this block, no greater than
     // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
     // the end of the sequence.
-    int block_S = min(BLOCK_SIZE, s_end - s_block_begin),
-        block_T = min(BLOCK_SIZE, t_end - t_block_begin);
+    // The last element of the output matrix p we write is (s_end, t_end),
+    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
+        block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
+    if (block_S <= 0 || block_T <= 0)
+      continue;
+    bool is_origin_block = (s_block_begin * t_block_begin == 0);
     // Load px_buf and py_buf.  We exponentiate; the assumption is that they
     // most likely won't overflow or underflow, but if they do overflow we'll
     // detect it later; we'll also detect certain kinds of underflow.
     for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-      int t_in_block = i % BLOCK_SIZE, s_in_block = i / BLOCK_SIZE,
+      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
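As a concrete check (illustrative numbers): with BLOCK_SIZE = 32, s_end = 100 and s_block_begin = 96, this gives block_S = min(32, 100 + 1 - 96) = 5 rows; a block starting at s_block_begin = 128 would get block_S = -27 <= 0 and be skipped by the `continue` above.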
@@ -305,7 +306,7 @@ void mutual_information_kernel(
       p_buf_s1_t = p_buf[s + 1][0];
     }

-    for (int i = 1; i < 2 * BLOCK_SIZE; i++) {
+    for (int i = 1; i < block_S + block_T; i++) {
       // i is the inner iteration, which corresponds to the (s + t) indexes of
       // the elements within the block that we write.  So i == 0 writes
       // positions (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
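A minimal host-side sketch (not part of the commit) of the enumeration order this comment describes; block_S and block_T here stand for the block sizes computed earlier:

#include <cstdio>

// Illustration only: print the (s, t) positions covered by each inner
// iteration i, i.e. the anti-diagonal s + t == i of one block.
void show_wavefront(int block_S, int block_T) {
  for (int i = 0; i < block_S + block_T - 1; i++)
    for (int s = 0; s < block_S; s++) {
      int t = i - s;
      if (t >= 0 && t < block_T)
        printf("i = %d writes (s, t) = (%d, %d)\n", i, s, t);
    }
}

The kernel's loop starts at i = 1, presumably because the (0, 0) corner of the block is handled before the loop.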
@@ -402,44 +403,6 @@ void mutual_information_kernel(
-/*
-  Summing reduction within a one-dimensional thread block, but with a
-  stride of N, so that we separately sum up the values of all threads with
-  threadIdx.x % N == 0, with threadIdx.x % N == 1, and so on.  At the end,
-  threads with 0 <= threadIdx.x < N contain the sums.
-
-  So this is like tiled summing reduction except that the tiles are
-  interspersed with each other.
-
-  Args:
-      N:   The number we sum modulo (must be a power of 2 with
-           1 <= N <= blockDim.x), i.e. all threads with
-           threadIdx.x % N == n for some 0 <= n < N have `val` summed.
-      buf: Pointer to the start of a __shared__ buffer of size
-           blockDim.x, to be used as a temporary within this function.
-      val: The value to be summed.
-  Return:
-      Threads where threadIdx.x < N will return the sums (over the threads
-      with the same value of threadIdx.x % N); the return value in other
-      threads is undefined.
-*/
-template <typename scalar_t>
-__forceinline__ __device__ scalar_t strided_reduce_sum(int N,
-                                                       __volatile__ scalar_t *buf,
-                                                       scalar_t val) {
-  // Each iteration halves the number of active threads.  Each thread adds
-  // its partial sum[i] to sum[lane + i].
-  for (int i = blockDim.x / 2; i >= N; i /= 2) {
-    buf[threadIdx.x] = val;
-    __syncthreads();
-    if (threadIdx.x < i)
-      val += buf[threadIdx.x + i];
-  }
-  return val;  // Only threads with threadIdx.x < N will return the full sums
-               // of their groups.
-}
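For reference, a hypothetical call site for this (now deleted) helper, with made-up buffer names; each thread contributes one value, and afterwards threads 0..N-1 hold the per-group totals:

// Illustration only: hypothetical fragment of a kernel with blockDim.x == 256.
// With N == 8, the thread groups {0, 8, 16, ...}, {1, 9, 17, ...}, ... are
// summed separately; after the call, threads 0..7 hold the eight group totals.
__shared__ float reduce_buf[256];          // must have blockDim.x elements
float val = my_partial_value;              // hypothetical per-thread value
float group_sum = strided_reduce_sum(8, reduce_buf, val);
if (threadIdx.x < 8)
  group_totals[threadIdx.x] = group_sum;   // hypothetical output buffer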
/*
Backward of mutual_information.
@@ -503,7 +466,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
      p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
                        p_grad[b][s][t + 1] * yderiv[b][s][t]    (eq. 6)
-     px_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]   (eq. 7)
+     px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]   (eq. 7)
      py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]   (eq. 8)

  (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
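A minimal serial CPU sketch (not from the commit) of what (eq. 6) through (eq. 8) compute for one batch element, assuming xderiv and yderiv are precomputed and p_grad[s_end][t_end] has been seeded with ans_grad; the array types and bounds here are assumptions for illustration:

// Illustration only: serial reference for (eq. 6)-(eq. 8), one batch element.
// Out-of-range terms are treated as zero, matching the kernel's edge handling.
void reference_backward(int s_begin, int s_end, int t_begin, int t_end,
                        double **xderiv, double **yderiv, double **p_grad,
                        double **px_grad, double **py_grad) {
  for (int s = s_end; s >= s_begin; s--)
    for (int t = t_end; t >= t_begin; t--) {
      if (!(s == s_end && t == t_end)) {  // the endpoint already holds ans_grad
        double up    = (s < s_end) ? p_grad[s + 1][t] * xderiv[s][t] : 0.0;
        double right = (t < t_end) ? p_grad[s][t + 1] * yderiv[s][t] : 0.0;
        p_grad[s][t] = up + right;                                     // (eq. 6)
      }
      if (s < s_end) px_grad[s][t] = p_grad[s + 1][t] * xderiv[s][t];  // (eq. 7)
      if (t < t_end) py_grad[s][t] = p_grad[s][t + 1] * yderiv[s][t];  // (eq. 8)
    }
}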
@@ -525,9 +488,11 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
     torch::PackedTensorAccessor32<int64_t, 2> boundary,  // B, 4;  or 0, 0 if boundaries are the defaults (0, 0, S, T)
-    int iter) {
-  // This kernel is sequentially called with 'iter' = num_iters - 1, num_iters - 2, .. 0,
-  // where num_iters can be taken to be any sufficiently large number but will actually be:
-  //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
+    int iter) {
+  // This kernel is sequentially called with 'iter' = num_iters - 1,
+  // num_iters - 2, .. 0, where num_iters can be taken to be any
+  // sufficiently large number but will actually be:
+  //   num_s_blocks + num_t_blocks - 1, where num_s_blocks = S / BLOCK_SIZE + 1
+  //   and num_t_blocks = T / BLOCK_SIZE + 1.
   const int B = px.size(0), S = px.size(1), T = py.size(2);
@@ -543,6 +508,9 @@ void mutual_information_backward_kernel(
   // but then modified to store the "xderiv" and "yderiv" values defined
   // in (eq. 5) and (eq. 6) above.  For out-of-range values, we'll write 0.0
   // here.
+  // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
+  // py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Unlike in the forward code, there is no offset of 1 in the indexes.
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
                       py_buf[BLOCK_SIZE][BLOCK_SIZE];
@@ -565,278 +533,195 @@ void mutual_information_backward_kernel(
   // boundary information supplied.
   __shared__ int64_t boundary_buf[4];

-  boundary_buf[0] = 0;
-  boundary_buf[1] = 0;
-  boundary_buf[2] = S;
-  boundary_buf[3] = T;
-
-  const int B = input.size(0), C = input.size(1), T = input.size(2),
-      N = params.size(1) - 1, K = N / 2;
-  // Note: N and K are powers of 2, with K >= 1.
-  const int c = blockIdx.x;  // c is channel index
-
-  scalar_t *y_vals = (scalar_t*) extern_buf,  // [N], actually there are three
-                                              // spaces between here and
-                                              // `params_buf` for storing scale
-                                              // and inv_scale and l == params[c][0].
-      *params_buf = (scalar_t*) y_vals + 3 + N;
-          // [N].  Contains parameters (not times scale!)
-          // Caution: contains params[c][1] through params[c][N],
-          // i.e. numbering is off by 1 versus params.
-          // params_buf[-1] contains params[c][0] == log of scale;
-          // params_buf[-2] and params_buf[-3] contain scale and inv_scale.
-
-  __shared__ scalar_t input_buf[THREADS_PER_BLOCK];  // input sequence
-  __shared__ scalar_t output_grad_buf[THREADS_PER_BLOCK];
-  __shared__ char n_buf[THREADS_PER_BLOCK];  // for each input in `input_buf`,
-                                             // this stores the integer value 0
-                                             // <= n < N which determines which
-                                             // piece of the piecewise linear
-                                             // function we are in.
-  // Load parameters
-  if (threadIdx.x <= N)
-    params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
-  __syncthreads();
   if (threadIdx.x == 0) {
-    scalar_t scale = exp(params_buf[-1]);
-    params_buf[-2] = scale;
-    params_buf[-3] = 1.0 / scale;
+    boundary_buf[0] = 0;
+    boundary_buf[1] = 0;
+    boundary_buf[2] = S;
+    boundary_buf[3] = T;
   }
   __syncthreads();
-  if (threadIdx.x == 0) {
-    scalar_t scale = params_buf[-2], sum_positive = 0.0;
-    for (int i = 0; i < K; i++) {
-      // params_buf is indexed with an index one less than params.
-      scalar_t pos_scaled_param = params_buf[K + i] * scale;
-      y_vals[K + i] = sum_positive - pos_scaled_param * i;
-      sum_positive += pos_scaled_param;
-    }
-  } else if (threadIdx.x == 64) {
-    scalar_t scale = params_buf[-2], sum_negative = 0.0;
-    for (int i = 0; i < K; i++) {
-      scalar_t neg_scaled_param = params_buf[K - i - 1] * scale;
-      sum_negative -= neg_scaled_param;
-      y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-    }
-  }
-  __syncthreads();
-
-  // this_param_grad and this_y_grad pertain to the 'n' value (i.e. the n'th
-  // linear interval) corresponding to n == threadIdx.x % N.  For example, if
-  // threadIdx.x == 0, this thread's gradient corresponds to the left-most
-  // linear interval.
-  scalar_t this_param_grad = 0.0, this_y_vals_grad = 0.0;
-  scalar_t inv_scale = params_buf[-3];
-
-  int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
-      b_offset = threadIdx.x / T_inc;  // offset within batch
-  for (int b = blockIdx.y * images_per_thread_block + b_offset;
-       b < B;
-       b += gridDim.y * images_per_thread_block) {
-    // The following will loop just once if images_per_thread_block > 1.  If
-    // images_per_thread_block == 1 and T > THREADS_PER_BLOCK, we will loop
-    // multiple times.  We want to keep all threads active so that output_grad
-    // will be set to zero for excess threads, and thus won't contribute to
-    // this_params_grad or this_y_vals_grad.
-    for (int t_offset = 0; t_offset < T; t_offset += THREADS_PER_BLOCK) {
-      // The following is equivalent to:
-      //   int t = (threadIdx.x % T_inc) + t_offset;
-      // given that T_inc is a power of 2 and t_offset >= THREADS_PER_BLOCK >= T_inc.
-      int t = (threadIdx.x & (T_inc - 1)) | t_offset;
-      scalar_t this_input = 0.0, this_output_grad;
-      if (t < T) {
-        this_output_grad = output_grad[b][c][t];
-        this_input = input[b][c][t];
-        input_buf[threadIdx.x] = this_input;
-        output_grad_buf[threadIdx.x] = this_output_grad;
-      }
-      scalar_t x = this_input * inv_scale + K;
-      if (x < 0) x = 0;
-      else if (x >= N) x = N - 1;
-      // The forward code did:
-      //   output[b][c][t] = this_input * params_buf[n] + y_vals[n];
-      // We get the derivative for params and y_vals later.
-      if (t < T) {
-        int n = (int) x;  // C++ rounds toward zero.
-        n_buf[threadIdx.x] = (char) n;
-        input_grad[b][c][t] = this_output_grad * params_buf[n];
-      } else {
-        n_buf[threadIdx.x] = 255;
-      }
+  // batch_block_iter iterates over both batch elements (index b), and block
+  // indexes in the range [0..num_blocks_this_iter-1].  The order here
+  // doesn't matter, since there are no interdependencies between these
+  // blocks (they are on a diagonal).
+  for (int batch_block_iter = blockIdx.x;
+       batch_block_iter < B * num_blocks_this_iter;
+       batch_block_iter += gridDim.x) {
+    int b = batch_block_iter % B,
+        block = batch_block_iter / B;
+    int s_block_begin = block * BLOCK_S_SIZE,
+        t_block_begin = (iter - block) * BLOCK_T_SIZE;
-      int this_block_start = threadIdx.x & ~(N - 1),  // == N * (threadIdx.x / N),
-                                                      // since N is a power of 2
-          this_n = threadIdx.x & (N - 1);             // == threadIdx.x % N.
-      // this_n is the n value that this thread accumulates gradients for;
-      // it is responsible for output_grads in the block of threads
-      // from this_block_start to this_block_start+N-1.
-
-      // __syncthreads();  // <- not really needed.
-      // At this point there is an implicit within-warp
-      // synchronization (Note: implicit warp synchronization is not considered
-      // future-proof).  Threads above have written to n_buf, and threads below
-      // will read from it; but we don't need to explicitly synchronize for now
-      // because the reads/writes are among threads in a group of N threads with
-      // (4 <= N <= 16); and 16 is less than the warp size which is 32 or 64.
-
-      // src_indexes will contain up to 16 4-bit numbers, stored starting in its
-      // least significant bits.  It will store all the offsets within this
-      // block of N threads, whose chosen 'n' value equals this_n.
-      uint64_t src_indexes = 0;
-      // num_src is the number of numbers in `src_indexes`.  We need to store a
-      // separate counter because zero is a valid index and if we are to support
-      // N == 16 we don't have bits to spare in src_indexes to store some kind
-      // of marker.
-      int num_src = 0;
-
-      // This loop always does at least N statements, but they should be
-      // relatively fast ones since the computation per n value is minimal and
-      // there is little I/O.  We are figuring out the subset of our block of N
-      // elements, which this particular thread value is responsible for
-      // (because they have n == this_n), and storing them in `src_indexes` and
-      // `num_src`.
-      for (int i = 0; i < N; i += 4) {
-        uint32_t n_block_of_4 = *reinterpret_cast<uint32_t*>(n_buf + this_block_start + i);
-#pragma unroll
-        for (int j = 0; j < 4; ++j) {
-          // CUDA is little endian
-          char n = (char)(n_block_of_4 >> (8 * j));
-          if (n == this_n) {
-            // We require that N <= 16, so 4 bits is enough to store src_idx.
-            src_indexes = (src_indexes << 4) | (i + j);
-            ++num_src;
-          }
-          // Note: if, for out-of-range threads, we had values not in [0..N-1] in
-          // n_buf they won't end up mattering even though they are read here,
-          // because they won't equal this_n.  For values 0 <= n < N originating
-          // in out-of-range threads, the value won't matter because the
-          // corresponding value in output_grad_buf will be zero.
-        }
-      }
+    if (threadIdx.x < 4 && boundary.size(0) != 0)
+      boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
+    __syncthreads();
-      // While num_src could theoretically be as large as N, the hope is that no
-      // thread in any given warp actually loops that many times.  Once all
-      // threads in the warp are finished looping, we can continue.  It is OK
-      // for different warps to get out of sync here; we could be looping over a
-      // number of images, and the hope is that different warps will reach the
-      // end of the outer loop at around the same time because their variations
-      // in speed will average out.
-      for (; num_src > 0; --num_src, (src_indexes >>= 4)) {
-        int src_thread = this_block_start | (src_indexes & 0xF);
-        scalar_t src_output_grad = output_grad_buf[src_thread],
-            src_input = input_buf[src_thread];
-        assert(n_buf[src_thread] == this_n);
-        n_buf[src_thread] = 0;
-        // Backprop for: output = input * params_buf[n] + y_vals[n].
-        // Here, n == this_n; this is how we selected these `src_idx` values.
-        this_param_grad += src_output_grad * src_input;
-        this_y_vals_grad += src_output_grad;
-      }
+    int s_begin = boundary_buf[0],
+        t_begin = boundary_buf[1],
+        s_end = boundary_buf[2],
+        t_end = boundary_buf[3];
+    s_block_begin += s_begin;
+    t_block_begin += t_begin;
-      // TODO: remove the next lines
-      assert(n_buf[threadIdx.x] == 0 ||
-             (unsigned char) n_buf[threadIdx.x] == 255);
-      output_grad_buf[threadIdx.x] = 0.0;
-    }
-  }
+    // block_S and block_T are the actual sizes of this block, no greater than
+    // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
+    // the end of the sequence.
+    // The last element of the output matrix p we write is (s_end, t_end),
+    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
+        block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);
-  __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf for reduction, and, later, input_buf.
+    if (block_S <= 0 || block_T <= 0)
+      continue;
-  this_param_grad = strided_reduce_sum(N, output_grad_buf, this_param_grad);
-  __syncthreads();
-  this_y_vals_grad = strided_reduce_sum(N, output_grad_buf, this_y_vals_grad);
+    // Load px_buf and py_buf.  At this point they just contain px and py
+    // for this block.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      // We let px and py default to -infinity if they are out of range, which will
+      // cause xderiv and yderiv for out-of-range values to be zero, and cause
+      // correct behavior in edge cases (for the top and right blocks).
+      // The issue is that p and p_grad are of larger size than px and py.
+      scalar_t this_px = -INFINITY;
+      if (s < s_end && t <= t_end)
+        this_px = px[b][s - 1][t];
+      px_buf[s_in_block][t_in_block] = this_px;
+      scalar_t this_py = -INFINITY;
+      if (s <= s_end && t < t_end)
+        this_py = py[b][s][t - 1];
+      py_buf[s_in_block][t_in_block] = this_py;
+    }
-  __syncthreads();  // sync threads because we are about to re-use
-                    // output_grad_buf as y_vals_grad_buf.
-  // Re-use some buffers..
-  scalar_t *params_grad_buf = input_buf + 1,  // [N] ... but element [-1] will
-                                              // have deriv of scale.
-      *y_vals_grad_buf = output_grad_buf;     // [N]
+    // load p.  This time we loop over the exact indexes we need.  Above
+    // we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
+    // because having power-of-2 arrangement of threads may be helpful
+    // for aligned reads, but here the loop is up to
+    // (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1), which is not a power of 2,
+    // so that is not a concern here.
+    for (int i = threadIdx.x;
+         i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1);
+         i += blockDim.x) {
+      int s_in_block = i / (BLOCK_SIZE + 1),  // 0 <= s_in_block <= block_S
+          t_in_block = i % (BLOCK_SIZE + 1),  // 0 <= t_in_block <= block_T
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      // Setting 0.0 for out-of-bounds elements, together with setting
+      // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
+      // ensure that we do the right thing in top and right edge cases,
+      // i.e. that no derivatives will be propagated from out-of-bounds points.
+      p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
+                                       p[b][s][t] : 0.0);
+    }
-  if (threadIdx.x < N) {
-    params_grad_buf[threadIdx.x] = this_param_grad;
-    y_vals_grad_buf[threadIdx.x] = this_y_vals_grad;
-  }
-  __syncthreads();  // other threads are about to read params_grad_buf and
-                    // y_vals_grad_buf.
+    // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      // We can apply this formula to the entire block even if we are processing
+      // a partial block; elements outside the partial block will not be used so
+      // their values don't matter, and elements just out
+      int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
+      // Mathematically the following is doing:
+      //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
+      // (with an offset on the s and t indexes)
+      px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+      // Mathematically the following is doing:
+      //   yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
+      // (with an offset on the s and t indexes)
+      py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
+    }
-  // This next block does backprop relating to `y_vals`.  Comparing with the CPU
-  // version (call this the "reference code") is the best way to understand this
-  // (this code is just a modification of that).  The main difference is we
-  // modify the indexes into params and params_grad by -1, so the index
-  // corresponds to the 'n' value; and element -1 of params_grad_buf will have
-  // the deriv of the log scale.
+    // Load p_grad for the top and right elements in p_buf: i.e. for elements
+    // p_buf[s][t] where s == block_S (exclusive-or) t == block_T.  We don't
+    // need to load the top-right corner [block_S][block_T]; that location will
+    // never be accessed.
+    // These are the p_grad values computed by previous instances of this kernel.
+    // If this is one of the top or right blocks, some or all of the p_grad
+    // values we'd be reading here will be out of range, and we use zeros.
+    if (threadIdx.x < block_S) {
+      int s_in_block = threadIdx.x,
+          t_in_block = block_T,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
+                                       p_grad[b][s][t] : 0.0);
+    } else if (static_cast<unsigned int>(threadIdx.x - 64) <
+               static_cast<unsigned int>(block_T)) {
+      int s_in_block = block_S,
+          t_in_block = threadIdx.x - 64,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ?
+                                       p_grad[b][s][t] : 0.0);
+    }
-  scalar_t l_grad;
-  if (threadIdx.x == 0) {
-    // Now do the backprop for the loop above where we set y_vals.  This could
-    // be further optimized to replace the loop with a raking, but I doubt this
-    // will have a huge effect on the runtime since K will be fairly small,
-    // e.g. 4.
-    scalar_t scale = params_buf[-2],
-        scale_grad = 0.0, sum_positive_grad = 0.0;
-    for (int i = K - 1; i >= 0; i--) {
-      // Backprop for: sum_positive += pos_scaled_param;
-      scalar_t pos_scaled_param_grad = sum_positive_grad;
-      // Backprop for: y_vals[K + i] = sum_positive - pos_scaled_param * i;
-      scalar_t y_grad_pos = y_vals_grad_buf[K + i];
-      pos_scaled_param_grad -= i * y_grad_pos;
-      sum_positive_grad += y_grad_pos;
-      // Backprop for: pos_scaled_param = params_buf[K + i] * scale,
-      params_grad_buf[K + i] += pos_scaled_param_grad * scale;
-      scale_grad += pos_scaled_param_grad * params_buf[K + i];
-    }
+    // The number of inner iterations, i.e. iterations inside this
+    // kernel, is this_num_inner_iters.  The highest iteration,
+    // corresponding to the highest-indexed value of p_buf that
+    // we need to set, corresponds to p_buf[block_S - 1][block_T - 1],
+    // and the iteration number is the sum of these indexes, i.e.
+    // (block_S - 1) + (block_T - 1).
+    bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
+                           t_block_begin + block_T == t_end + 1);
+    int first_iter = block_S + block_T - 2;
+    if (is_final_block) {
+      // The following statement, mathematically, corresponds to:
+      //   p_grad[b][s_end][t_end] = ans_grad[b].  Normally this element of p_buf
+      // would be set by the first iteration of the loop below, so if it's set
+      // this way we have to decrement first_iter to prevent it being
+      // overwritten.
+      p_buf[block_S - 1][block_T - 1] = ans_grad[b];
+      --first_iter;
+    }
-    // Backprop for: scale = exp(l), where l = params[c][0].
-    l_grad = scale * scale_grad;
-  }
-  else if (threadIdx.x == 64) {
-    // Now do the backprop for the loop above where we set y_vals.
-    // Make this one threadIdx.x == 0 so it's possibly quicker to test
-    //
-    scalar_t scale = params_buf[-2],
-        scale_grad = 0.0, sum_negative_grad = 0.0;
-    for (int i = K - 1; i >= 0; i--) {
-      // Backprop for: y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-      scalar_t y_grad_neg = y_vals_grad_buf[K - i - 1];
-      sum_negative_grad += y_grad_neg;
-      scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-      // Backprop for: sum_negative -= neg_scaled_param;
-      neg_scaled_param_grad -= sum_negative_grad;
-      // Backprop for: neg_scaled_param = params_buf[K - i - 1] * scale;
-      params_grad_buf[K - i - 1] += neg_scaled_param_grad * scale;
-      scale_grad += neg_scaled_param_grad * params_buf[K - i - 1];
-    }
+    for (int i = first_iter; i >= 0; --i) {
+      int s = threadIdx.x, t = i - threadIdx.x;
+      if (t >= 0) {
+        // The following statement is really operating on the gradients;
+        // it corresponds to (eq. 6) defined above, i.e.:
+        //   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
+        //                     p_grad[b][s][t + 1] * yderiv[b][s][t]
+        p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
+                       p_buf[s][t + 1] * py_buf[s][t]);
+      }
+    }
-    params_grad_buf[-1] = scale * scale_grad;
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    params_grad_buf[-1] += l_grad;  // contribution to l grad from the
-                                    // "negative" branch
-  }
-  __syncthreads();
-  if (threadIdx.x <= N)
-    params_grad[blockIdx.y][c][threadIdx.x] = params_grad_buf[threadIdx.x - 1];
+    // Write out p_grad, px_grad and py_grad.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int t_in_block = i % BLOCK_SIZE, s_in_block = i / BLOCK_SIZE,
+          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
+      if (t <= t_end && s <= s_end) {
+        p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
+        if (s < s_end) {
+          // write px_grad, which is of shape [B][S][T + 1].
+          // From (eq. 7):
+          //   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
+          px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
+                              px_buf[s_in_block][t_in_block]);
+        }
+        if (t < t_end) {
+          // write py_grad, which is of shape [B][S + 1][T].
+          // From (eq. 8):
+          //   py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]
+          py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
+                              py_buf[s_in_block][t_in_block]);
+        }
+      }
+    }
+    if (threadIdx.x == 0 && s_block_begin == s_begin && t_block_begin == t_begin)
+      ans_grad[b] = p_buf[0][0];
+  }
+}

 // forward of mutual_information.  See """... """ comment of `mutual_information` in
 // mutual_information.py for documentation of the behavior of this function.
 torch::Tensor mutual_information_cuda(torch::Tensor px,
@@ -861,18 +746,19 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
   torch::Tensor ans = torch::empty({B}, opts);

-  int num_threads = 128,
-      num_blocks = 128;
+  int num_threads = 128,
+      num_blocks = 128,
+      BLOCK_SIZE = 32;
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
-      num_iters = std::max<int>(num_s_blocks, num_t_blocks);
+      num_iters = num_s_blocks + num_t_blocks - 1;

   bool has_boundary = (bool) optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);

-  for (int iter = 0; iter < num_iters; iter++) {
-    mutual_information_kernel<scalar_t, 32><<<num_blocks, num_threads>>>(
+  for (int iter = 0; iter < num_iters; ++iter) {
+    mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
       px.packed_accessor32<scalar_t, 3>(),
       py.packed_accessor32<scalar_t, 3>(),
       p.packed_accessor32<scalar_t, 3>(),
@@ -880,141 +766,64 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
       ans.packed_accessor32<scalar_t, 1>(),
       iter);
   }

-  int grid_dim_y = 1;
-  // If the number of channels is quite small (<128) we can launch more thread
-  // groups, splitting on the batch index.
-  while (C * grid_dim_y < 128)
-    grid_dim_y *= 2;
-  // B_reduced is the max number of thread-groups per channel that would have
-  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
-  // launching kernels with nothing to do.
-  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
-  if (grid_dim_y > B_reduced)
-    grid_dim_y = B_reduced;
-
-  int shared_mem_numel = 2 * N + 3;
-
-  if (false)
-    std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
-              << ", images_per_thread_block = " << images_per_thread_block
-              << ", grid_dim_y = " << grid_dim_y << "\n";
-
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
-              images_per_thread_block == 1,
-              "Code error");
-  TORCH_CHECK(N + 1 <= THREADS_PER_BLOCK,
-              "Values of N this large are not supported.");
-
-  dim3 gridDim(C, grid_dim_y, 1);
-  // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_kernel",
-                             ([&] {
-    mutual_information_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK,
-                                          sizeof(scalar_t) * shared_mem_numel,
-                                          at::cuda::getCurrentCUDAStream()>>>(
-        input.packed_accessor32<scalar_t, 3>(),
-        params.packed_accessor32<scalar_t, 2>(),
-        output.packed_accessor32<scalar_t, 3>(),
-        images_per_thread_block);
-  }));
-  return output;
+  return ans;
 }
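A short note on the `boundary` argument these launchers thread through: when supplied it is a [B][4] int64 tensor holding (s_begin, t_begin, s_end, t_end) per batch element. A hypothetical construction (illustrative values; the nested-initializer torch::tensor overload is assumed):

// Illustration only: a boundary tensor for B == 2 sequences.
torch::Tensor boundary = torch::tensor(
    {{0, 0, 50, 80},      // element 0: use a smaller region
     {0, 0, 100, 200}},   // element 1: use the full region
    torch::dtype(torch::kInt64));
std::optional<torch::Tensor> optional_boundary(boundary);
// When no boundary is supplied, the launchers above substitute an empty
// {0, 0} tensor and the kernels fall back to the defaults (0, 0, S, T).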
-std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
-                                                            torch::Tensor params,
-                                                            torch::Tensor output_grad) {
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(output_grad.dim() == 3 &&
-              output_grad.size(0) == input.size(0) &&
-              output_grad.size(1) == input.size(1) &&
-              output_grad.size(2) == input.size(2),
-              "output_grad and input have mismatched dim.");
-  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
-  TORCH_CHECK(output_grad.device().is_cuda(), "output_grad must be a CUDA tensor");
-  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1;
-  TORCH_CHECK(N >= 4, "This backward code requires N >= 4");
-  TORCH_CHECK(N <= 16, "This backward code currently requires N <= 16");
-  TORCH_CHECK((N & (N - 1)) == 0, "N must be a power of 2");
-
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-  torch::Tensor input_grad = torch::empty({B, C, T}, opts);
-  if (C * B * T == 0) {
-    return std::vector<torch::Tensor>({input_grad,
-                                       torch::empty({C, N + 1})});
-  }
-  int images_per_thread_block = 1;
-  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK &&
-         images_per_thread_block * 2 * N <= THREADS_PER_BLOCK)
-    images_per_thread_block *= 2;
-  int grid_dim_y = 1;
-  // If the number of channels is quite small (<128) we can launch more thread
-  // groups, splitting on the batch index.
-  while (C * grid_dim_y < 128)
-    grid_dim_y *= 2;
-  // B_reduced is the max number of thread-groups per channel that would have
-  // any work to do.  If grid_dim_y is more than this, we reduce it to avoid
-  // launching kernels with nothing to do.
-  int B_reduced = (B + images_per_thread_block - 1) / images_per_thread_block;
-  if (grid_dim_y > B_reduced)
-    grid_dim_y = B_reduced;
+// backward of mutual_information; returns (grad_px, grad_py)
+std::vector<torch::Tensor> mutual_information_backward_cuda(
+    torch::Tensor px,
+    torch::Tensor py,
+    std::optional<torch::Tensor> optional_boundary,
+    torch::Tensor p,
+    torch::Tensor ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
-  int shared_mem_numel = 2 * N + 3;
+  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
+              p.device().is_cuda() && ans_grad.device().is_cuda(),
+              "inputs must be CUDA tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
-  if (false)
-    std::cout << "C,B,T,N = " << C << "," << B << "," << T << "," << N
-              << ", images_per_thread_block = " << images_per_thread_block
-              << ", grid_dim_y = " << grid_dim_y << "\n";
+
+  const int B = px.size(0),
+      S = px.size(1),
+      T = px.size(2) - 1;
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= T ||
-              images_per_thread_block == 1,
-              "Code error");
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
-  TORCH_CHECK(THREADS_PER_BLOCK / images_per_thread_block >= N);
+
+  torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
+      px_grad = torch::empty({B, S, T + 1}, opts),
+      py_grad = torch::empty({B, S + 1, T}, opts);
-  torch::Tensor params_grad = torch::zeros({grid_dim_y, C, N + 1}, opts);
+  const int num_threads = 128,
+      num_blocks = 128,
+      BLOCK_SIZE = 32;
-  dim3 gridDim(C, grid_dim_y, 1);
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+      num_t_blocks = T / BLOCK_SIZE + 1,
+      num_iters = num_s_blocks + num_t_blocks - 1;
-  // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_kernel",
-                             ([&] {
-    mutual_information_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK,
-                                                   sizeof(scalar_t) * shared_mem_numel,
-                                                   at::cuda::getCurrentCUDAStream()>>>(
-        input.packed_accessor32<scalar_t, 3>(),
-        params.packed_accessor32<scalar_t, 2>(),
-        output_grad.packed_accessor32<scalar_t, 3>(),
-        input_grad.packed_accessor32<scalar_t, 3>(),
-        params_grad.packed_accessor32<scalar_t, 3>(),
-        images_per_thread_block);
-  }));
+  bool has_boundary = (bool) optional_boundary;
+  if (!has_boundary)
+    optional_boundary = torch::empty({0, 0}, long_opts);
-  params_grad = at::sum(params_grad, {0});
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+  for (int iter = num_iters - 1; iter >= 0; --iter) {
+    mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>(
+        px.packed_accessor32<scalar_t, 3>(),
+        py.packed_accessor32<scalar_t, 3>(),
+        p.packed_accessor32<scalar_t, 3>(),
+        ans_grad.packed_accessor32<scalar_t, 1>(),
+        p_grad.packed_accessor32<scalar_t, 3>(),
+        px_grad.packed_accessor32<scalar_t, 3>(),
+        py_grad.packed_accessor32<scalar_t, 3>(),
+        optional_boundary.value().packed_accessor32<int64_t, 2>(),
+        iter);
+  }
+  return std::vector<torch::Tensor>({px_grad, py_grad});
+}