OpenDAS / FAST-RNNT

Commit 3c1ec347
Authored Jul 29, 2021 by Daniel Povey
Parent: 8ed6deff

    Get it to a stage where it looks like it might compile

Changes: 3 changed files, with 277 additions and 331 deletions.
torch_mutual_information/mutual_information.py              +33   -27
torch_mutual_information/mutual_information_cpu.cpp         +16   -113
torch_mutual_information/mutual_information_cuda_kernel.cu  +228  -191
torch_mutual_information/mutual_information.py

@@ -44,66 +44,72 @@ except ImportError:
 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
                                            boundaries: torch.Tensor,
-                                           q: torch.Tensor) -> torch.Tensor:
+                                           p: torch.Tensor) -> torch.Tensor:
     if input.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, q)
+            px, py, boundaries, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, q)
+            px, py, boundaries, p)


 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
                                             boundaries: torch.Tensor,
-                                            q: torch.Tensor
-                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                                            p: torch.Tensor,
+                                            ans_grad: torch.Tensor
+                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
-        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, q))
+        overwrite_ans_grad = True
+        if overwrite_ans_grad:
+            ans_grad_copy = ans_grad.clone()
+        ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
+            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+        if overwrite_ans_grad:
+            if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
+                print(f"Warning: possible excessive roundoff in mutual information backward "
+                      f"recursion: {ans_grad} vs. {ans_grad_copy}")
+        return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, q))
+            px, py, boundaries, p, ans_grad))
 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
-                boundaries: torch.Tensor) -> torch.Tensor:
-        (B, S, T) = px.shape
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
+                boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+        (B, S, T1) = px.shape
+        T = T1 - 1
         assert py.shape == (B, S + 1, T)
         if boundaries is not None:
             assert boundaries.shape == (B, 4)
-        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
+        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is
         # the mutual information of the pair of subsequences of x and y that are of
         # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
         # the mutual information of the entire pair of sequences, i.e. of lengths
         # S and T respectively.
-        # q is a rearrangement of a tensor p which is of shape (B,S,T),
-        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
-        # representation is that each row of q depends only on the previous row,
-        # so we can access the rows sequentially and this leads to
-        # better memory access patterns.  We are assuming that most likely
-        # T < S, which means that q should not require much more memory than p.
-        #
-        # Actually we access q beginning from 0 indexes even if `boundaries`
-        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
-        #   q[b, s-s_begin + t-t_begin, t-t_begin];
         # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
         # It is computed as follows (in C++ and CUDA):
         #   p[b,0,0] = 0.0
         #   p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
         #                      p[b,s,t-1] + py[b,s,t-1])
         #            if s > 0 or t > 0,
         #   treating values with any -1 index as -infinity.
         # .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
+        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, q)
+            ctx.save_for_backward(px, py, boundaries, p)
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, q) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(
-            px, py, boundaries, q)
+        (px, py, boundaries, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(
+            px, py, boundaries, p, ans_grad)
         return (px_grad, py_grad, None)
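As an aid to reading this commit: the recursion documented in forward() above can be written as a naive pure-PyTorch reference. The following is a sketch under our own naming (`mutual_information_ref` is hypothetical, not part of this repository), and it assumes no `boundaries`, i.e. full sequences:

import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: (B, S, T + 1); py: (B, S + 1, T).  Returns ans == p[:, S, T].
    (B, S, T1) = px.shape
    T = T1 - 1
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype)
    p = torch.empty(B, S + 1, T + 1, dtype=px.dtype)
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                p[:, 0, 0] = 0.0   # probability of the (0, 0)-length pair
                continue
            # treat any -1 index as -infinity, as in the comments above
            term1 = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term2 = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term1, term2)
    return p[:, S, T]

This is O(S * T) Python-level loops, so it is only useful as a correctness oracle for the C++/CUDA kernels below, not for real workloads.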
torch_mutual_information/mutual_information_cpu.cpp

@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

-  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+  bool has_boundary = (bool)optional_boundary;
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
+      px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S, T + 1}, opts)),
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
+                 torch::empty({B, S + 1, T}, opts));

   auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
-  bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);
@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
         auto px_a = px.packed_accessor32<scalar_t, 3>(),
             py_a = py.packed_accessor32<scalar_t, 3>(),
             p_a = p.packed_accessor32<scalar_t, 3>(),
-            p_grad_a = p.packed_accessor32<scalar_t, 3>();
+            p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
+            px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
+            py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
         auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
           // .. which obtains p_a[b][s][t - 1] from a register.
           scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
               // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1],  <-- not
               // actually needed..
               total = p_a[b][s][t],
               term1_deriv = exp(term1 - total),
               term2_deriv = 1.0 - term1_deriv,
               grad = p_grad_a[b][s][t],
               term1_grad = term1_deriv * grad,
               term2_grad = term2_deriv * grad;
           // We can assign to px_grad_a here rather than add, because we
           // know it's currently zero.
-          TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
           px_grad_a[b][s - 1][t] = term1_grad;
-          TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);  // likewise..
-          p_grad_a[b][s - 1][t] = term1_grad
-          py_grad_a[b][s][t - 1] += term2_grad;
+          p_grad_a[b][s - 1][t] = term1_grad;  // likewise..
+          py_grad_a[b][s][t - 1] = term2_grad;
+          p_grad_a[b][s][t - 1] += term2_grad;
         }
       }
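The term1_deriv / term2_deriv logic above is just backprop through log-add: for total = logaddexp(term1, term2), d(total)/d(term1) = exp(term1 - total), and the two partial derivatives sum to exactly 1 (which is why term2_deriv can be computed as 1.0 - term1_deriv). A small check of this, ours, via autograd:

import torch

term1 = torch.tensor(0.5, requires_grad=True)
term2 = torch.tensor(-1.0, requires_grad=True)
total = torch.logaddexp(term1, term2)
total.backward()

# d(total)/d(term1) == exp(term1 - total)
assert torch.isclose(term1.grad, (term1 - total).detach().exp())
# the two derivatives are a convex combination: they sum to 1
assert torch.isclose(term1.grad + term2.grad, torch.tensor(1.0))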
@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
     }
   }
 }));
-  return ans;
-}
-
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(input.sizes() == output_grad.sizes(),
-              "Output-grad vs. input sizes mismatch.");
-  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
-  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1,
-      K = N / 2;
-
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-
-  torch::Tensor y_vals = torch::empty({C, N}, opts),
-      y_vals_grad = torch::zeros({C, N}, opts),
-      params_grad = torch::zeros({C, N + 1}, opts),
-      input_grad = torch::zeros({B, C, T}, opts);
-
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
-                             "mutual_information_backward_cpu_loop", ([&] {
-    auto params_a = params.accessor<scalar_t, 2>(),
-        params_grad_a = params_grad.accessor<scalar_t, 2>(),
-        y_vals_a = y_vals.accessor<scalar_t, 2>(),
-        y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
-    for (int c = 0; c < C; c++) {
-      scalar_t sum_negative = 0.0,
-          sum_positive = 0.0,
-          scale = exp(params_a[c][0]);
-      for (int i = 0; i < K; i++) {
-        scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-            neg_scaled_param = params_a[c][K - i] * scale;
-        y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-        sum_positive += pos_scaled_param;
-        sum_negative -= neg_scaled_param;
-        y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-      }
-    }
-    auto input_a = input.accessor<scalar_t, 3>(),
-        output_grad_a = output_grad.accessor<scalar_t, 3>(),
-        input_grad_a = input_grad.accessor<scalar_t, 3>();
-    for (int b = 0; b < B; b++) {
-      for (int c = 0; c < C; c++) {
-        scalar_t inv_scale = exp(-params_a[c][0]);
-        for (int t = 0; t < T; t++) {
-          scalar_t input = input_a[b][c][t],
-              x = input * inv_scale + K,
-              output_grad = output_grad_a[b][c][t];
-          if (x < 0) x = 0;
-          else if (x >= N) x = N - 1;
-          // C++ rounds toward zero.
-          int n = (int)x;
-          // OK, at this point, 0 <= n < 2*K.
-          // backprop for:
-          // output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-          params_grad_a[c][n + 1] += output_grad * input;
-          y_vals_grad_a[c][n] += output_grad;
-          input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
-        }
-      }
-    }
-    // Now do the backprop for the loop above where we set y_vals_a.
-    for (int c = 0; c < C; c++) {
-      scalar_t scale = exp(params_a[c][0]),
-          scale_grad = 0.0,
-          sum_negative_grad = 0.0,
-          sum_positive_grad = 0.0;
-      for (int i = K - 1; i >= 0; i--) {
-        // Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-        scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
-        sum_negative_grad += y_grad_neg;
-        scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-        // Backprop for: sum_negative -= neg_scaled_param;
-        neg_scaled_param_grad -= sum_negative_grad;
-        // Backprop for: sum_positive += pos_scaled_param;
-        scalar_t pos_scaled_param_grad = sum_positive_grad;
-        // Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-        scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
-        pos_scaled_param_grad -= i * y_grad_pos;
-        sum_positive_grad += y_grad_pos;
-        // Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
-        // and: neg_scaled_param = params_a[c][K - i] * scale;
-        params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
-        params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
-        scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
-                       neg_scaled_param_grad * params_a[c][K - i]);
-      }
-      // Backprop for: scale = exp(params_a[c][0]),
-      params_grad_a[c][0] += scale * scale_grad;
-    }
-  }));
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+  std::cout << "p_grad = " << p_grad;
+  return std::vector<torch::Tensor>({px_grad, py_grad});
 }
torch_mutual_information/mutual_information_cuda_kernel.cu

@@ -73,7 +73,7 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
     p[b,0,0] = 0.0
     p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
-                       p[b,s,t-1] + py[b,s,t-1])
+                       p[b,s,t-1] + py[b,s,t-1])      (eq. 0)
              if s > 0 or t > 0,
     treating values with any -1 index as -infinity.
     .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
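The LogAdd referenced in this hunk's context is not itself shown in the diff; for orientation, the standard numerically stable formulation of log(exp(x) + exp(y)) that such device functions implement looks like this sketch (ours):

import math

def log_add(x: float, y: float) -> float:
    # Factor out the max so the remaining exponent is <= 0 and cannot overflow.
    if x < y:
        x, y = y, x            # now x >= y
    if y == -math.inf:
        return x               # avoid (-inf) - (-inf) = nan below
    return x + math.log1p(math.exp(y - x))

assert abs(log_add(0.0, 0.0) - math.log(2.0)) < 1e-12
assert log_add(-math.inf, 3.0) == 3.0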
@@ -122,32 +122,33 @@ void mutual_information_kernel(
       num_t_blocks = T / BLOCK_SIZE + 1;
-  // num_blocks_this_iter is an upper bound on the number of blocks of size
-  // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration.  We go
-  // from the bottom left of the image so that on iter == 0 we process only one
-  // block with block-index (0, 0) then on iter == 1 we process block-indexes
-  // (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0,
-  // 2); and so on.  We also will never have more than `num_s_blocks` blocks
-  // (We'll never have more than num_t_blocks either, but the numbering we use
-  // corresponds to s and not t, so if we hit the num_t_blocks limit, the
-  // lowest-numbered blocks on s would just not be active and we'll 'continue'
-  // below).
+  // num_blocks_this_iter is an upper bound on the number of blocks of size
+  // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
+  // These iterations start from the bottom left of the image so that on iter ==
+  // 0 we process only one block with block-index (0, 0); then on iter == 1 we
+  // process block-indexes (1, 0) and (0, 1); and then on iter == 2 we process
+  // (2, 0), (1, 1) and (0, 2); and so on.  We will also never have more than
+  // `num_s_blocks` blocks (we'll never have more than num_t_blocks either, but
+  // the numbering we use corresponds to s and not t, so when we hit the
+  // num_t_blocks limit, the blocks with the lowest s indexes would just not be
+  // active and we'll 'continue' in the loop below).
   int num_blocks_this_iter = min(iter + 1, num_s_blocks);
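To make the anti-diagonal ("wavefront") block schedule described in that comment concrete, here is a small sketch (ours, hypothetical helper name) that enumerates, for each outer iteration, the (s_block, t_block) pairs that would be active:

# On outer iteration `it`, s-block index `block` handles t-block (it - block).
def blocks_per_iter(num_s_blocks: int, num_t_blocks: int):
    num_iters = num_s_blocks + num_t_blocks - 1
    for it in range(num_iters):
        yield it, [(block, it - block)
                   for block in range(min(it + 1, num_s_blocks))
                   if it - block < num_t_blocks]  # inactive blocks 'continue'

for it, active in blocks_per_iter(3, 2):
    print(it, active)
# 0 [(0, 0)]
# 1 [(0, 1), (1, 0)]
# 2 [(1, 1), (2, 0)]
# 3 [(2, 1)]

Every block on one anti-diagonal depends only on blocks from earlier diagonals, which is what lets each iteration's blocks run in parallel.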
   // For the block with s_block_begin == 0 and t_block_begin == 0 (for
   // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
-  // for out-of-range indexes.
+  // for out-of-range indexes into px.
   // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
       py_buf[BLOCK_SIZE][BLOCK_SIZE];

-  // 1st row/col of p_buf correspond to the previous blocks, or to an edge case.
-  // So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t
-  // - 1] - normalizer); or 0 for out-of-range values.
-  // p_buf[s][t] == exp(p[s+s_block_begin-1][t+t_block_begin-1] - normalizer).
+  // 1st row/col of p_buf correspond to the previously computed blocks (lower
+  // `iter`), or to negative indexes into p.  So, for the origin block,
+  // p_buf[s][t] corresponds to exp(p[s - 1][t - 1] - normalizer); or 0 for
+  // out-of-range values.
   __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

   // boundary_buf will be used to store the b'th row of `boundary` if we have
-  // boundary information supplied.
+  // boundary information supplied; or (0, 0, S, T) otherwise.
   __shared__ int64_t boundary_buf[4];

   if (threadIdx.x == 0) {
@@ -157,69 +158,70 @@ void mutual_information_kernel(
     boundary_buf[3] = T;
   }

-  // batch_block_iter iterates over both batch elements (index b), and block
-  // indexes in the range [0..num_blocks_this_iter-1]
+  // batch_block_iter iterates over batch elements (index b) and block
+  // indexes in the range [0..num_blocks_this_iter-1], combining both
+  // batch and block indexes.
   for (int batch_block_iter = blockIdx.x;
        batch_block_iter < B * num_blocks_this_iter;
        batch_block_iter += gridDim.x) {
-    int b = batch_block_iter % B,
-        block = batch_block_iter / B;
-    int s_block_begin = block * BLOCK_S_SIZE,
-        t_block_begin = (iter - block) * BLOCK_T_SIZE;
+    int block = batch_block_iter / B,
+        b = batch_block_iter % B;  // b is the index into the batch
+    // Note: `block` can be no greater than `iter` because num_blocks_this_iter
+    // <= iter + 1, so iter - block >= 0.
+    int s_block_begin = block * BLOCK_SIZE,
+        t_block_begin = (iter - block) * BLOCK_SIZE;

-    bool is_origin_block = (s_block_begin * t_block_begin == 0);
-    if (threadDim.x < 4 && boundary.size(0) != 0)
+    if (boundary.size(0) != 0 && threadIdx.x < 4)
       boundary_buf[threadDim.x] = boundary[b][threadDim.x];

     __syncthreads();

     int s_begin = boundary_buf[0],
         t_begin = boundary_buf[1],
         s_end = boundary_buf[2],
         t_end = boundary_buf[3];
     s_block_begin += s_begin;
     t_block_begin += t_begin;

-    // block_S and block_T are the actual sizes of this block, no greater than
-    // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
-    // the end of the sequence.
-    // The last element of the output matrix p we write is (s_end, t_end),
-    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    // block_S and block_T are the actual sizes of this block (the block of `p`
+    // that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but
+    // possibly less than that if we are towards the end of the sequence.  The
+    // last element in the output matrix p that we need to write is (s_end,
+    // t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
     int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
         block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);

     if (block_S <= 0 || block_T <= 0)
       continue;

+    bool is_origin_block = (s_block_begin * t_block_begin == 0);
-    // Load px_buf and py_buf.  We exponentiate; the assumption is that they most likely
-    // won't overflow or underflow, but if they do overflow we'll detect it later; we'll
-    // also detect certain kinds of underflow.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+    // Load px_buf and py_buf.  We exponentiate; the assumption is that they
+    // most likely won't overflow or underflow, but if they do overflow we'll
+    // detect it later; we'll also detect certain kinds of underflow.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
       int s_in_block = i / BLOCK_SIZE,
           t_in_block = i % BLOCK_SIZE,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
-      // the comparisons with S and T below just make sure we don't access
-      // out-of-memory regions; they do not guarantee we are in the range given
-      // by s_begin, s_end and so on.  Note: comparing as unsigned int makes sure
-      // the index is nonnegative.
+      // comparing as unsigned int makes sure the index is nonnegative.
       scalar_t this_px = 0.0;
-      if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) && t <= T)
+      if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(s_end) && t <= t_end)
        this_px = exp(px[b][s - 1][t]);
      px_buf[s_in_block][t_in_block] = this_px;

      scalar_t this_py = 0.0;
-      if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) && s <= S)
+      if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(t_end) && s <= s_end)
        this_py = exp(py[b][s][t - 1]);
      py_buf[s_in_block][t_in_block] = this_py;
    }
-    // Load the 1st row and column of p_buf (except element[0][0] is not needed).
-    // Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer.
+    // Load the 1st row and 1st column of p_buf (except element [0][0] is not
+    // needed).  This is the context from previously computed blocks of the
+    // image.  Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
+    // 1][t + t_block_begin - 1] - normalizer).
     if (threadIdx.x < 64) {  // 64 == warp size.  First half of threads...
       if (threadIdx.x <= BLOCK_SIZE) {
         // s_in_p_buf are simply the indexes into p_buf
@@ -227,16 +229,14 @@ void mutual_information_kernel(
             t_in_p_buf = 0,
             s = s_in_p_buf + s_block_begin - 1,
             t = t_in_p_buf + t_block_begin - 1;
         // The if-statement below just guards against out-of-range memory
         // accesses, it does not guarantee that we really need these values.
         scalar_t this_p = -INFINITY;
-        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
-            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
-          this_p = p[s + s_block_begin][s + t_block_begin];
+        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
+            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
+          this_p = p[b][s][t];
         p_buf[threadIdx.x][0] = this_p;
       }
     } else {  // Another warp handles the other leg
-      if (threadIdx.x - 64 <= BLOCK_SIZE) {
+      if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
         int s_in_p_buf = 0,
             t_in_p_buf = threadIdx.x - 64,
             s = s_in_p_buf + s_block_begin - 1,
@@ -244,19 +244,19 @@ void mutual_information_kernel(
         // The if-statement below just guards against out-of-range memory
         // accesses, it does not guarantee that we really need these values.
         scalar_t this_p = -INFINITY;
-        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
-            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
-          this_p = p[s + s_block_begin][s + t_block_begin];
+        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
+            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
+          this_p = p[b][s][t];
         p_buf[threadIdx.x][0] = this_p;
       }
     }
     __syncthreads();
-    // We read p_buf in log-space; subtract 'normalizer', which mathematically
-    // could be any finite number, to get in a reasonable range of probabilities,
-    // and then exponentiate.  We'll do everything in non-log space, and later
-    // take a log before we write out the data.
+    // We read p_buf in log-space; we now subtract 'normalizer', which
+    // mathematically could be any finite number, to get it in a range close to
+    // zero, and then exponentiate.  We'll do everything in non-log space, for
+    // speed, and later take a log before we write out the data.
     scalar_t normalizer = (is_origin_block ? 0.0 :
                            max(px_buf[0][1], px_buf[1][0]));
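The normalizer trick being described is ordinary exp-space stabilization: shift the log-values by a reference, work with ratios near 1.0, and shift back when taking the final log. A small illustration of the idea (ours, not the kernel's code):

import torch

logp = torch.tensor([-1000.0, -1001.0, -1002.5])  # log-space values
assert (logp.exp() == 0).all()          # naive exp() underflows to zero

normalizer = logp.max()                 # any finite value works mathematically
ratios = (logp - normalizer).exp()      # safe range, near 1.0
# ... arithmetic happens in non-log space ...
recovered = normalizer + ratios.log()   # back to log space at the end
assert torch.allclose(recovered, logp)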
@@ -265,50 +265,55 @@ void mutual_information_kernel(
     // and we'll overwrite with 1.0 if there is a panic situation due to
     // overflow.
     if (threadIdx.x <= BLOCK_SIZE) {
-      if (threadIdx.x == 0) {
-        // p_buf[0][0] is never used for its normal purpose; we set it to zero.
-        // We'll later write an infinity there if something goes wrong, as a
-        // 'panic' indicator.
-        p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
-                                 exp(p_buf[threadIdx.x][0] - normalizer));
-      }
-    } else if ((int)threadIdx.x - 64 < BLOCK_SIZE) {
+      // p_buf[0][0] is never used for its normal purpose; we set it to zero
+      // p_buf[0][0] = 0.0;  <-- for search purposes.
+      // We'll later write an infinity there if something goes wrong, as a
+      // 'panic' indicator.
+      p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
+                               exp(p_buf[threadIdx.x][0] - normalizer));
+    } else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
       // this happens in a different warp so can be in parallel to the code above.
       p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
     }

-    if (threadIdx.x == 0) {
+    if (threadIdx.x == 0 && is_origin_block) {
       // This if-statement is an optimization and modification of the loop below
-      // for the value i == 0, i.e. inner-iteration == 0.  The modification is
-      // to use 0.0 if this is the "origin block" with s_block_begin == 0 and
-      // t_block_begin == 0.  This corresponds to the probability of the pair of
-      // sequences of length (0, 0).
-      p_buf[1][1] = (is_origin_block ? 0.0 :
+      // for the value i == 0, i.e. inner-iteration == 0.  The modification is
+      // to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
+      // i.e. s == s_begin, t == t_begin.  This corresponds to the
+      // probability of the pair of sequences of length (0, 0).
+      p_buf[1][1] = (is_origin_block ? 1.0 :
                      p_buf[0][1] * px_buf[0][0] +
                      p_buf[1][0] * py_buf[0][0]);
     }
-    scalar_t p_buf_s1_t;  // This is for an optimization.
-    if (i < BLOCK_SIZE) {
+    scalar_t p_buf_s1_t;  // This is for an optimization to avoid one
+                          // shared-memory read/write in the loop below.  it
+                          // represents p_buf[s + 1][t]; the first time we
+                          // access this, it will be for t == 0, except for
+                          // thread 0 when we first need it for t == 1.
+    if (threadIdx.x < BLOCK_SIZE) {
       int s = threadIdx.x;
-      p_buf_s1_t = p_buf[s + 1][0];
+      p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
     }
-    for (int i = 1; i < block_S + block_T; i++) {
-      // i is the inner iteration, which corresponds to the (s + t) indexes of the
-      // elements within the block that we write.  So i == 0 writes positions
-      // (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
-      // (0, 2), (1, 1) and (2, 0); and so on.
-      // Note: not many threads participate in this part, only up to BLOCK_SIZE
-      // at most.  Unfortunately we couldn't figure out a very meaningful way
-      // for more threads to do work, that looked like it would really speed
-      // things up.
+    for (int i = 1; i < block_S + block_T - 1; ++i) {
+      // i is the inner iteration, which corresponds to the (s + t) indexes of
+      // the elements within the block that we write.  So i == 0 writes
+      // positions (s, t) == (0, 0) (but we treated i == 0 as a special case
+      // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
+      // and (2, 0); and so on.  Note: not many threads participate in this
+      // part, only up to BLOCK_SIZE at most.  Unfortunately we couldn't figure
+      // out a very meaningful way for more threads to do work, that looked like
+      // it would really speed things up.
       // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
       // but we do at least do the I/O in an efficient way and keep the
       // inner loop simple and fast (e.g. no exp() or log()).
       int s = threadIdx.x,
           t = i - s;
       if (t >= 0) {
         if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
           // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
           // row and column for context from previous blocks.  Taking into account
           // the way these buffers relate to the tensors p, px and py, and
@@ -320,7 +325,7 @@ void mutual_information_kernel(
           //
           // where you can see that apart from the offsets of tbb and sbb, this is
           // the same as the recursion defined for p in
-          // mutual_information.py:mutual_information_recursion().
+          // mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
 #if 1
           p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
                                 p_buf[s + 1][t] * py_buf[s][t];
 #else
@@ -328,18 +333,30 @@ void mutual_information_kernel(
           // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
           // the need for a load from shared memory.
           p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
+          // The next time this thread reads p_buf_s1_t, t will be one greater,
+          // so p_buf_s1_t will contain p_buf[s + 1][t].  The first time this
+          // thread uses p_buf_s1_t is when t == 0, except for thread 0 where
+          // the 1st item accessed is for s == 0, t == 1.
           p_buf[s + 1][t + 1] = p_buf_s1_t;
 #endif
         }
         // We don't need to do __syncthreads() in this loop because all the
         // threads that are active are in the same warp.  (However, in future,
         // if NVidia changes some things, we might need to sync here).
       }
       __syncthreads();
     }
-    // Write out the data.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-      int t = i % BLOCK_SIZE,
-          s = i / BLOCK_SIZE;
-      if (s < block_S && t < block_T) {
-        float this_p = p_buf[s + 1][t + 1];
-        p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p);
+    // Write out the data to p; check that nothing has gone out of numerical
+    // range, and write 'panic' flag if it has.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE,
+          t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      if (s_in_block < block_S && t_in_block < block_T) {
+        float this_p = p_buf[s_in_block + 1][t_in_block + 1];
+        p[b][s][t] = normalizer + log(this_p);
+        // If this_p is infinity, NaN or zero...
+        if (this_p - this_p != 0 || this_p == 0)
+          p_buf[0][0] = 1.0;  // This is a "panic" flag.
       }
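The overflow test added here, `this_p - this_p != 0 || this_p == 0`, relies on x - x being 0 for every finite x but NaN for x = +/-inf or NaN. A quick illustration (ours):

import math

for x in [1.0, 2.5e38, 0.0, math.inf, -math.inf, math.nan]:
    panic = (x - x != 0) or (x == 0)
    print(x, panic)
# finite nonzero -> False; 0.0 (underflow), inf, -inf, nan -> True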
@@ -351,27 +368,31 @@ void mutual_information_kernel(
     // Write `ans`, if this is the final (top-right) block in its sequence
     // Logically, the following equation corresponds to:
     //   ans[b] = p[b][s_end][t_end]
-    if (s_block_begin + S > s_end && t_block_begin + T > t_end)
-      ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]);
+    if (s_block_begin + block_S - 1 == s_end &&
+        t_block_begin + block_T - 1 == t_end) {
+      // you could read block_S below as block_S - 1 + 1, meaning,
+      // it's the last index in a block of size block_S, but the indexes into
+      // p_buf have a "+ 1".  Likewise for block_T.
+      ans[b] = normalizer + log(p_buf[block_S][block_T]);
+    }
   }
   if (p_buf[0][0] != 0.0) {
-    // "panic" flag set.  We need to re-do the computation using log-add.
+    // The "panic" flag is set.  We need to re-do the computation using log-add.
     // This time we won't use the buffers, we'll just load and save from main
     // memory.  This code should very rarely be reached; and anyway, caching
     // should help us quite a bit.
-    for (int i = 0; i < 2 * BLOCK_SIZE; i++) {
-      int block_s = threadIdx.x,
-          block_t = i - block_s;
-      if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T) &&
-          block_s < block_S) {
-        int s = block_s + s_block_begin,
-            t = block_t + t_block_begin;
+    for (int i = 0; i < block_S + block_T - 1; ++i) {
+      int s_in_block = threadIdx.x,
+          t_in_block = i - block_s;
+      if (s_in_block < block_S &&
+          static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
+        int s = s_in_block + s_block_begin,
+            t = t_in_block + t_block_begin;
-        float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
-            p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
-            this_px = px[b][s][t],
-            this_py = py[b][s][t];
+        float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
+            this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
+            p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
+            this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
         float this_p = LogAdd(p_s1 + this_px, p_t1 + this_py);
         if (i == 0 && is_origin_block)

@@ -382,7 +403,8 @@ void mutual_information_kernel(
     if (threadIdx.x == 0) {
       // Write `ans`, if this is the final (top-right) block in its sequence.
       // This is only reached in the 'panic situation' where we had overflow.
-      if (s_block_begin + S > s_end && t_block_begin + T > t_end)
+      if (s_block_begin + block_S - 1 == s_end &&
+          t_block_begin + block_T - 1 == t_end)
         ans[b] = p[b][s_end][t_end];
     }
   }
@@ -402,17 +424,18 @@ void mutual_information_kernel(
       ep[b][s][t - 1] * epy[b][s][t - 1].      (eq. 1)

-  (A) First we consider the part that involves recursion, i.e. the part
-  involving only gradients of ep.  The backprop involving ep only would be:
+  (A) First we consider the part of the backprop that requires recursion or
+  iteration, i.e. the part involving only gradients of ep.  This is:

       ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
       ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].

   .. and if we add 1 to the s index of the first equation above and 1 to the
   t index of the second equation, we can see that:

       ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
                          ep_grad[b][s][t + 1] * epy[b][s][t].

-  Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
+  Now, if ep = exp(p), and y is the loss function we are backprop'ing, then
+  ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
      == dy/dp / ep. == p_grad / ep.
   I.e. ep_grad = p_grad / ep.
   So we can write the above as:
@@ -425,8 +448,8 @@ void mutual_information_kernel(
   (B) The following is the backprop for epx and epy from (eq. 1):

-      epx_grad[b][s - 1][t] +=
-            ep_grad[b][s][t] * ep[b][s - 1][t]
-      epy_grad[b][s][t - 1] +=
-            ep_grad[b][s][t] * ep[b][s][t - 1]
+      epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
+      epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]

   .. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd:
@@ -435,7 +458,7 @@ void mutual_information_kernel(
   Using, similar to the above, ep_grad = p_grad / ep, and similarly,
   epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
-  the above becomes
+  the above becomes:

     px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
     py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
@@ -450,11 +473,11 @@ void mutual_information_kernel(
      yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])    (eq. 5)

   .. and note that these quantities are <= 1 so there is no problem doing
-  the exponentiation.  So the recursion can be simplified as:
+  the exponentiation.  So the recursion, from eqs. (2, 3a, 3b), can be
+  simplified as:

      p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
                        p_grad[b][s][t + 1] * yderiv[b][s][t]              (eq. 6)
-     px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t]             (eq. 7)
+     px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]             (eq. 7)
      py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]             (eq. 8)

   (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
@@ -462,8 +485,9 @@ void mutual_information_kernel(
   write to shared memory within the loop that's the limiting factor.)

   The backward pass will be slightly different from the forward pass in terms of
-  how we store p (and p_grad), because for writing a particular block of p_grad, we
-  need context on the top and right instead of the bottom and left.
+  how we store and index p (and p_grad), because for writing a particular block
+  of p_grad, we need context on the top and right instead of the bottom and
+  left.  So there are offsets of 1.
 */
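To make eqs. (4)-(8) concrete, here is a naive loop-based reference for the whole backward recursion (ours, hypothetical name `mi_backward_ref`; it assumes no `boundary` and that all entries of p are finite), in the same spirit as the kernel but without blocking:

import torch

def mi_backward_ref(px, py, p, ans_grad):
    # px: (B, S, T+1); py: (B, S+1, T); p: (B, S+1, T+1) from the forward
    # pass; ans_grad: (B,).  Returns (px_grad, py_grad).
    (B, S, T1) = px.shape
    T = T1 - 1
    p_grad = torch.zeros_like(p)
    px_grad = torch.zeros_like(px)
    py_grad = torch.zeros_like(py)
    p_grad[:, S, T] = ans_grad          # d(ans)/d(p[S][T]) == 1
    for s in range(S, -1, -1):
        for t in range(T, -1, -1):
            if s < S:
                # eq. 4: xderiv = exp(p[s][t] + px[s][t] - p[s+1][t])
                xderiv = (p[:, s, t] + px[:, s, t] - p[:, s + 1, t]).exp()
                px_grad[:, s, t] = p_grad[:, s + 1, t] * xderiv   # eq. 7
                p_grad[:, s, t] += p_grad[:, s + 1, t] * xderiv   # eq. 6
            if t < T:
                # eq. 5: yderiv = exp(p[s][t] + py[s][t] - p[s][t+1])
                yderiv = (p[:, s, t] + py[:, s, t] - p[:, s, t + 1]).exp()
                py_grad[:, s, t] = p_grad[:, s, t + 1] * yderiv   # eq. 8
                p_grad[:, s, t] += p_grad[:, s, t + 1] * yderiv   # eq. 6
    return px_grad, py_grad

The descending s, t iteration order guarantees that p_grad[s + 1][t] and p_grad[s][t + 1] are final before p_grad[s][t] is read, mirroring the top-right-to-bottom-left wavefront of the kernel.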
 template <typename scalar_t>
 __global__

@@ -472,8 +496,6 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> py,   // B, S + 1, T.
     torch::PackedTensorAccessor32<scalar_t, 3> p,    // B, S + 1, T + 1.  Produced in forward pass.
     torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,  // [B].  This is an input.
-    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare,  // [B].  A value will be written to
-                                                                  // here which should ideally equal ans_grad.
     torch::PackedTensorAccessor32<scalar_t, 3> p_grad,   // B, S + 1, T + 1.  This is a temporary.
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
@@ -483,16 +505,18 @@ void mutual_information_backward_kernel(
                // be any sufficiently large number but will actually be:
                // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
                // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
-    bool overwrite_ans_grad) {   // If true, overwrite ans_grad with a value
-                                 // which, if everything is working correctly,
-                                 // should be identical or very close to the
-                                 // value of ans_grad that was passed in.
+    bool overwrite_ans_grad) {   // If overwrite_ans_grad == true, this function
+                                 // will overwrite ans_grad with a value which,
+                                 // if everything is working correctly, should be
+                                 // identical or very close to the value of
+                                 // ans_grad that was passed in.
   const int B = px.size(0),
       S = px.size(1),
       T = py.size(2);
-  // For statements that are the same as the forward pass, we are omitting some comments
-  // what we made there.  We'll focus, in the comments, on differences from the forward pass.
+  // For statements that are the same as the forward pass, we are omitting some
+  // comments.  We'll focus, in the comments, on differences from the forward
+  // pass.
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_blocks_this_iter = min(iter + 1, num_s_blocks);
@@ -502,29 +526,33 @@ void mutual_information_backward_kernel(
   // but then modified to store the "xderiv" and "yderiv" values defined
   // in (eq. 5) and (eq. 6) above.  For out-of-range values, we'll write 0.0
   // here.
-  // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
-  // py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Initially (before xderiv/yderiv are written):
+  //   px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
+  //   py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Later (see eq. 4 and eq. 5):
+  //   px_buf[s][t] contains exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
+  //   py_buf[s][t] contains exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1])
+  //   where ss == s + s_block_begin, tt = t + t_block_begin.
+  // Unlike in the forward code, there is no offset of 1 in the indexes.
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
       py_buf[BLOCK_SIZE][BLOCK_SIZE];

   // p_buf is initially used to store p, and then (after we are done putting
   // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
   // p_grad.
   //
-  // Unlike in the forward pass, p_buf has the same numbering as px_buf and
-  // py_buf not offset by 1: e.g., for the origin block, p_buf[0][0] refers
-  // to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
-  // the block for px_buf and py_buf; unlike in the forward pass, we store
-  // context on the top right, not the bottom left, i.e. the elements at
-  // (one past the largest indexes in the block).
+  // Unlike in the forward pass, p_buf has the same numbering as px_buf and
+  // py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0]
+  // refers to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
+  // the block for px_buf and py_buf; unlike in the forward pass, we store
+  // context on the top and right, not the bottom and left, i.e. the elements at
+  // (one past the largest indexes in the block).
+  //
+  // For out-of-range elements of p_buf, we'll put zero.
   __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

   // boundary_buf will be used to store the b'th row of `boundary` if we have
-  // boundary information supplied.
+  // boundary information supplied; or (0, 0, S, T) if not.
   __shared__ int64_t boundary_buf[4];

   if (threadIdx.x == 0) {
@@ -541,13 +569,13 @@ void mutual_information_backward_kernel(
   for (int batch_block_iter = blockIdx.x;
        batch_block_iter < B * num_blocks_this_iter;
        batch_block_iter += gridDim.x) {
-    int b = batch_block_iter % B,
-        block = batch_block_iter / B;
-    int s_block_begin = block * BLOCK_S_SIZE,
-        t_block_begin = (iter - block) * BLOCK_T_SIZE;
+    int block = batch_block_iter / B,
+        b = batch_block_iter % B;
+    int s_block_begin = block * BLOCK_SIZE,
+        t_block_begin = (iter - block) * BLOCK_SIZE;

-    if (threadDim.x < 4 && boundary.size(0) != 0)
-      boundary_buf[threadDim.x] = boundary[b][threadDim.x];
+    if (threadIdx.x < 4 && boundary.size(0) != 0)
+      boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];

     __syncthreads();

     int s_begin = boundary_buf[0],
@@ -560,68 +588,69 @@ void mutual_information_backward_kernel(
     // block_S and block_T are the actual sizes of this block, no greater than
     // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
     // the end of the sequence.
-    // The last element of the output matrix p we write is (s_end, t_end),
-    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    // The last element of the output matrix p_grad we write is (s_end, t_end),
+    // i.e. the one-past-the-end index of p_grad is (s_end + 1, t_end + 1).
     int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
         block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);

     if (block_S <= 0 || block_T <= 0)
       continue;

-    // Load px_buf and py_buf.  At this point they just contain px and py
+    // Load px_buf and py_buf.  At this point we just set them to the px and py
     // for this block.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
       int s_in_block = i / BLOCK_SIZE,
           t_in_block = i % BLOCK_SIZE,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
-      // We let ps and py default to -infinity if they are out of range, which will
+      // We let px and py default to -infinity if they are out of range, which will
       // cause xderiv and yderiv for out-of-range values to be zero, and cause
       // correct behavior in edge cases (for the top and right blocks).
+      // The issue is that p and p_grad are of larger size than px and py.
       scalar_t this_px = -INFINITY;
       if (s < s_end && t <= t_end)
-        this_px = px[b][s - 1][t];
+        this_px = px[b][s][t];
       px_buf[s_in_block][t_in_block] = this_px;

       scalar_t this_py = -INFINITY;
       if (s <= s_end && t < t_end)
-        this_py = py[b][s][t - 1];
+        this_py = py[b][s][t];
       py_buf[s_in_block][t_in_block] = this_py;
     }
-    // load p.  This time we loop over the exact indexes we need.  Above
-    // we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
-    // because having power-of-2 arrangement of threads may be helpful
-    // for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1)
-    // which is not a power of 2, so that is not a concern here.
-    for (int i = threadDim.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
-      int s_in_block = i / (BLOCK_SIZE + 1),  // 0 <= s_in_block <= block_S
-          t_in_block = i % (BLOCK_SIZE + 1),  // 0 <= t_in_block <= block_T
+    // load p.  We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
+    // reads more aligned.
+    for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) {
+      int s_in_block = i / (BLOCK_SIZE + 1),
+          t_in_block = i % (BLOCK_SIZE + 1),
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
       // Setting 0.0 for out-of-bounds elements, together with setting
       // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
       // ensure that we do the right thing in top and right edge cases,
-      // i.e. that no derivatives will be propagated from out-of-bounds points.
-      p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ? p[b][s][t] : 0.0);
+      // i.e. that no derivatives will be propagated from out-of-bounds points
+      // because the corresponding xderiv and yderiv values will be zero.
+      scalar_t this_p = 0.0;
+      if (s <= s_end && t <= t_end)
+        this_p = p[b][s][t];
+      p_buf[s_in_block][t_in_block] = this_p;
     }
     // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
       // We can apply this formula to the entire block even if we are processing
-      // a partial block; elements outside the partial block will not be used so
-      // their values don't matter, and elements just out
-      int t = i % BLOCK_SIZE,
-          s = i / BLOCK_SIZE;
+      // a partial block; we have ensured that x_buf and y_buf contain -infinity,
+      // and p contains 0, for out-of-range elements, so we'll get x_buf and y_buf
+      // containing 0 after applying the following formulas.
+      int s = i / BLOCK_SIZE,
+          t = i % BLOCK_SIZE;
       // Mathematically the following is doing:
       //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
       // (with an offset on the s and t indexes)
-      px_buf[s][t] = exp(px_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+      px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
       // Mathematically the following is doing:
       //   yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
       // (with an offset on the s and t indexes)
-      py_buf[s][t] = exp(px_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
+      py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
     }
     // Load p_grad for the top and right elements in p_buf: i.e. for elements
@@ -630,7 +659,8 @@ void mutual_information_backward_kernel(
     // never be accessed.
     // These are the p_grad values computed by previous instances of this kernel
     // If this is one of the top or right blocks, some or all of the p_grad
-    // values we'd be reading here will be out of range, and we use zeros.
+    // values we'd be reading here will be out of range, and we use zeros
+    // to ensure no gradient gets propagated from those positions.
     if (threadIdx.x < block_S) {
       int s_in_block = threadIdx.x,
           t_in_block = block_T,
@@ -638,34 +668,33 @@ void mutual_information_backward_kernel(
           t = t_in_block + t_block_begin;
       p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
-    } else if (static_cast<unsigned int>(threadIdx.x - 64) <
+    } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
               static_cast<unsigned int>(block_T)) {
+      // casting to unsigned before the comparison tests for both negative and
+      // out-of-range values of (int)threadIdx.x - 64.
       int s_in_block = block_S,
-          t_in_block = threadIdx.x - 64,
+          t_in_block = (int)threadIdx.x - 64,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
       p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
     }

-    // The number of inner iterations, i.e. iterations inside this
-    // kernel, is this_num_inner_iters.  The highest iteration,
-    // corresponding to the highest-indexed value of p_buf that
-    // we need to set, corresponds to p_buf[block_S - 1][block_T - 1],
-    // and the iteration number is the sum of these indexes, i.e.
-    // (block_S - 1) + (block_T - 1).
+    // The highest-numbered value in p_buf that we need (corresponding,
+    // of course, to p_grad), is:
+    //   p_buf[block_S - 1][block_T - 1],
+    // and the inner iteration number (i) on which we set this is the sum of
+    // these indexes, i.e. (block_S - 1) + (block_T - 1).
     bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
                            t_block_begin + block_T == t_end + 1);

     int first_iter = block_S + block_T - 2;
     if (is_final_block) {
-      // The following statement, mathematically, corresponds to:
-      //   p_grad[b][s_end][t_end] = ans_grad[b]  Normally this element of p_buf
-      // would be set by the first iteration of the loop below, so if it's set
-      // this way we have to decrement first_iter to prevent it being
-      // overwritten.
+      // The following statement corresponds to:
+      //   p_grad[b][s_end][t_end] = ans_grad[b]
+      // Normally this element of p_buf would be set by the first iteration of
+      // the loop below, so if it's set this way we have to decrement first_iter
+      // to prevent it from being overwritten.
       p_buf[block_S - 1][block_T - 1] = ans_grad[b];
       --first_iter;
     }
@@ -675,7 +704,8 @@ void mutual_information_backward_kernel(
         t = i - threadIdx.x;
     if (t >= 0) {
       // The following statement is really operating on the gradients;
-      // it corresponds to (eq. 6) defined above, i.e.:
+      // it corresponds, with offsets of s_block_begin and t_block_begin
+      // on the indexes, to (eq. 6) defined above, i.e.:
       //   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
       //                     p_grad[b][s][t + 1] * yderiv[b][s][t]
       p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
@@ -684,17 +714,19 @@ void mutual_information_backward_kernel(
     }

     // Write out p_grad, px_grad and py_grad.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-      int t_in_block = i % BLOCK_SIZE,
-          s_in_block = i / BLOCK_SIZE,
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE,
+          t_in_block = i % BLOCK_SIZE,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;
       // s_end and t_end are the one-past-the-end of the (x,y) sequences, but
       // the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
       if (t <= t_end && s <= s_end) {
         p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
         if (s < s_end) {
           // write px_grad, which is of shape [B][S][T + 1]
           // From (eq. 7):
-          //   px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t]
+          //   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
           px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
                               px_buf[s_in_block][t_in_block]);
         }
@@ -741,7 +773,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
   // num_threads and num_blocks and BLOCK_SIZE can be tuned.
   //  (however, num_threads may not be less than 128).
   int num_threads = 128,
-      num_blocks = 128,
+      num_blocks = 256,
       BLOCK_SIZE = 32;
   // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
@@ -802,14 +834,18 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+  TORCH_CHECK(ans_grad.size(0) == b);
+
+  bool has_boundary = (bool)optional_boundary;
   torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
-      px_grad = torch::empty({B, S, T + 1}, opts),
-      py_grad = torch::empty({B, S + 1, T}, opts);
+      px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S, T + 1}, opts)),
+      py_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S + 1, T}, opts));

   // num_threads and num_blocks and BLOCK_SIZE can be tuned.
   //  (however, num_threads may not be less than 128).
   const int num_threads = 128,
-      num_blocks = 128,
+      num_blocks = 256,
       BLOCK_SIZE = 32;
   // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
@@ -819,7 +855,7 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_iters = num_s_blocks + num_t_blocks - 1;

-  if ((bool)optional_boundary)
+  if (has_boundary)
     TORCH_CHECK(optional_boundary.value().device().is_cuda(),
                 "boundary information must be in CUDA tensor");
   else
@@ -838,5 +874,6 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
         iter, overwrite_ans_grad);
   }
+  std::cout << "p_grad = " << p_grad;
   return std::vector<torch::Tensor>({px_grad, py_grad});
 }