OpenDAS / FAST-RNNT · Commits

Commit 77eed83f, authored Jul 28, 2021 by Daniel Povey

    Some progress, still drafting.

parent e95d7864
Showing 4 changed files with 678 additions and 364 deletions.
torch_mutual_information/mutual_information.py               +29   -21
torch_mutual_information/mutual_information_cpu.cpp          +173  -58
torch_mutual_information/mutual_information_cuda.cpp         +17   -8
torch_mutual_information/mutual_information_cuda_kernel.cu   +459  -277
torch_mutual_information/mutual_information.py    (view file @ 77eed83f)

@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
        (B, S, T) = px.shape
        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
        # the mutual information of the pair of subsequences of x and y that are of
        # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
        # the mutual information of the entire pair of sequences, i.e. of lengths
        # S and T respectively.
        # q is a rearrangement of a tensor p which is of shape (B,S,T),
        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
        # representation is that each row of q depends only on the previous row,

@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
        #   q[b, s-s_begin + t-t_begin, t-t_begin];
        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
        if px.requires_grad or py.requires_grad:
            q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
        else:
            # We don't need to store q if we are not going to do backprop, but we
            # do pass in a temporary with one real row, expanded to have "fake" rows,
            # which happens to be convenient for the CPU implementation.
            q = torch.empty(1, 1, T, device=px.device, dtype=px.dtype).expand(B, S + T, T)
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
        if px.requires_grad or py.requires_grad:
            ctx.save_for_backward(px, py, boundaries, q)

@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
        make use of the formula computed by this function.

        Args:
          px: A torch.Tensor of some floating point type, with shape [B][S][T],
          px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
              where B is the batch size, S is the length of the 'x' sequence
              (including representations of EOS symbols but not BOS symbols), and T is the
              length of the 'y' sequence (including representations of

@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):
              code assumes for optimization purposes that the T axis has
              stride 1.
          py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
          py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
              representing
                 py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
              This function does not treat x and y differently; the only difference
              is that the implementation assumes for optimization purposes that y
              is likely to be the shorter sequence, i.e. that "most of the time T < S",
              and it will be faster if you respect this.
              is that for optimization purposes we assume the last axis (the t axis)
              has stride of 1; this is true if px and py are contiguous.
          boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
              [s_begin, t_begin, s_end, t_end].  If not supplied, the values
              [0, 0, S, T] will be assumed.  These are the beginning and

@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):
        Returns:
            Returns a torch.Tensor of shape [B], containing the log of the mutual
            information between the b'th pair of sequences.  This is defined by
            the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
            the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
            representing a mutual information between sub-sequences of lengths s and t:

                 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])

            where in the case where boundaries==None, the edge cases are handled
            by treating p[b,-1,-1] as 0 and all other quantities with negative
            indexes as -infinity; and ans[b] would equal p[S-1,T-1].  The extension to
            cases where the boundaries are specified should be obvious.

                 p[b,0,0] = 0.0
                 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                    p[b,s,t-1] + py[b,s,t-1])      (if s > 0 or t > 0)

            where we handle edge cases by treating quantities with negative indexes
            as -infinity.  The extension to cases where the boundaries are specified
            should be obvious; it just works on shorter sequences with offsets into
            px and py.
        """
    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
    assert px.ndim == 3
    B, S, T1 = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    assert px.dtype == py.dtype
    (B, S, T) = px.shape
    if boundaries is not None:
        assert boundaries.dtype == torch.LongTensor
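As a way to make the recursion documented in the docstring above concrete, here is a minimal, unvectorized sketch in plain PyTorch. It is illustrative only: the name `naive_mutual_information` is ours, and the real implementation dispatches to the C++/CUDA code in the other files of this commit.

import torch

def naive_mutual_information(px, py, boundaries=None):
    """Unvectorized reference for the recursion in the docstring above.
    px: [B, S, T+1], py: [B, S+1, T]; returns a tensor of shape [B]."""
    B, S, T1 = px.shape
    T = T1 - 1
    neg_inf = float("-inf")
    ans = torch.empty(B, dtype=px.dtype, device=px.device)
    for b in range(B):
        if boundaries is not None:
            s_begin, t_begin, s_end, t_end = boundaries[b].tolist()
        else:
            s_begin, t_begin, s_end, t_end = 0, 0, S, T
        p = torch.full((S + 1, T + 1), neg_inf, dtype=px.dtype, device=px.device)
        p[s_begin, t_begin] = 0.0
        for s in range(s_begin, s_end + 1):
            for t in range(t_begin, t_end + 1):
                if s == s_begin and t == t_begin:
                    continue
                # Terms that would index outside the region are treated as -inf.
                term_x = (p[s - 1, t] + px[b, s - 1, t]) if s > s_begin \
                    else torch.tensor(neg_inf, dtype=px.dtype, device=px.device)
                term_y = (p[s, t - 1] + py[b, s, t - 1]) if t > t_begin \
                    else torch.tensor(neg_inf, dtype=px.dtype, device=px.device)
                p[s, t] = torch.logaddexp(term_x, term_y)
        ans[b] = p[s_end, t_end]
    return ans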
torch_mutual_information/mutual_information_cpu.cpp    (view file @ 77eed83f)

@@ -3,6 +3,13 @@
inline double Exp(double x) { return exp(x); }
inline double Exp(float x) { return expf(x); }

// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
  double diff;

@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffDouble) {
  if (diff >= -1000) {
    double res;
    res = x + log1p(exp(diff));
    return res;
@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffFloat) {
  if (diff >= -200) {
    float res;
    res = x + log1pf(expf(diff));
    return res;
  }
  return x;  // return the larger one.
}
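The LogAdd helpers above all follow the same pattern: pull the larger argument out of the log so that exp() only ever sees a non-positive value and cannot overflow. A small, purely illustrative Python equivalent of that identity (using the float version's -200 cutoff from the code above):

import math

def log_add(x: float, y: float) -> float:
    # log(exp(x) + exp(y)), computed as max + log1p(exp(min - max)),
    # so the exp() argument is always <= 0 and cannot overflow.
    if x < y:
        x, y = y, x
    diff = y - x            # <= 0
    if diff < -200.0:       # exp(diff) is negligible here; mirrors the float cutoff above
        return x
    return x + math.log1p(math.exp(diff))

# A naive log(exp(1000) + exp(1000)) would overflow; this does not.
assert abs(log_add(1000.0, 1000.0) - (1000.0 + math.log(2.0))) < 1e-9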
// forward of mutual_information.  See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cpu(torch::Tensor px,
                                     torch::Tensor py,
                                     std::optional<torch::Tensor> optional_boundary,
                                     torch::Tensor q) {
                                     torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "params must be 3-dimensional.");
  TORCH_CHECK(q.dim() == 3, "params must be 3-dimensional.");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0), S = px.size(1), T = px.size(2);
  TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor ans = torch::empty({B}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64);
  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({}, long_opts);
    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().dtype == torch::kInt64);

  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
      auto px_a = px.accessor<scalar_t, 3>(),
           py_a = py.accessor<scalar_t, 3>();
      for (int c = 0; c < C; c++) {
        scalar_t sum_negative = 0.0, sum_positive = 0.0,
            scale = exp(params_a[c][0]);
        for (int i = 0; i < K; i++) {
          scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
              neg_scaled_param = params_a[c][K - i] * scale;
          y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
          sum_positive += pos_scaled_param;
          sum_negative -= neg_scaled_param;
          y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
           py_a = py.packed_accessor32<scalar_t, 3>(),
           p_a = p.packed_accessor32<scalar_t, 3>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
      auto ans_a = ans.packed_accessor32<scalar_t, 1>();

      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        p_a[b][s_begin][t_begin] = 0.0;
        for (int s = s_begin + 1; s <= s_end; ++s)
          p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
        for (int t = t_begin + 1; t <= t_end; ++t)
          p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
        for (int s = s_begin + 1; s <= s_end; ++s) {
          scalar_t p_s_t1 = p_a[b][s][t_begin];
          for (int t = t_begin + 1; t <= t_end; ++t) {
            // The following statement is a small optimization of:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
            // .. which obtains p_a[b][s][t - 1] from a register.
            p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
                                           p_s_t1 + py_a[b][s][t - 1]);
          }
        }
        ans_a[b] = p_a[b][s_end][t_end];
      }
    }));
  return ans;
}
      auto input_a = input.accessor<scalar_t, 3>(),
           output_a = output.accessor<scalar_t, 3>();
      for (int b = 0; b < B; b++) {
        for (int c = 0; c < C; c++) {
          scalar_t scale = exp(params_a[c][0]),
              inv_scale = 1.0 / scale;
          for (int t = 0; t < T; t++) {
            // `x` is the scaled input x plus an offset so that -K maps to 0.
            // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
            // so in a sense -K and +K are not special, but we include those
            // extra values as an easy way to handle the semi-infinite regions
            // that are < -(K-1) and > (K-1)
            scalar_t input = input_a[b][c][t],
                x = input * inv_scale + K;
            if (x < 0) x = 0;
            else if (x >= N) x = N - 1;
            // C++ rounds toward zero.
            int n = (int)x;
            // OK, at this point, 0 <= min < 2*K.
            output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
// backward of mutual_information.  Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> mutual_information_backward_cpu(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> optional_boundary,
    torch::Tensor p,
    torch::Tensor ans_grad) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu() &&
              ans_grad.device() == cpu(),
              "inputs must be CPU tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);

  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().dtype == torch::kInt64);

  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
      auto px_a = px.packed_accessor32<scalar_t, 3>(),
           py_a = py.packed_accessor32<scalar_t, 3>(),
           p_a = p.packed_accessor32<scalar_t, 3>(),
           p_grad_a = p.packed_accessor32<scalar_t, 3>();
      auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();

      for (int b = 0; b < B; b++) {
        int s_begin, s_end, t_begin, t_end;
        if (has_boundary) {
          s_begin = boundary_a[b][0];
          t_begin = boundary_a[b][1];
          s_end = boundary_a[b][2];
          t_end = boundary_a[b][3];
        } else {
          s_begin = 0;
          s_end = S;
          t_begin = 0;
          t_end = T;
        }
        // Backprop for: ans_a[b] = p_a[b][s_end][t_end];
        p_grad_a[b][s_end][t_end] = ans_grad_a[b];

        for (int s = s_end; s > s_begin; --s) {
          for (int t = t_end; t > t_begin; --t) {
            // The statement we are backpropagating here is:
            //   p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
            //                         p_a[b][s][t - 1] + py_a[b][s][t - 1]);
            // .. which obtains p_a[b][s][t - 1] from a register.
            scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
                total = p_a[b][s][t],
                term1_deriv = exp(term1 - total),
                term2_deriv = 1.0 - term1_deriv,
                grad = p_grad_a[b][s][t],
                term1_grad = term1_deriv * grad,
                term2_grad = term2_deriv * grad;
            // We can assign to px_grad_a here rather than add, because we
            // know it's currently zero.
            TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
            px_grad_a[b][s - 1][t] = term1_grad;
            TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);
            // likewise..
            p_grad_a[b][s - 1][t] = term1_grad;
            py_grad_a[b][s][t - 1] += term2_grad;
            p_grad_a[b][s][t - 1] += term2_grad;
          }
        }
      }}));
  return output;

        for (int t = t_end; t >= t_begin; --t) {
          // Backprop for:
          //   p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
          scalar_t this_p_grad = p_grad_a[b][s_begin][t];
          p_grad_a[b][s_begin][t - 1] += this_p_grad;
          py_grad_a[b][s_begin][t - 1] += this_p_grad;
        }
        for (int s = s_end; s >= s_begin; --s) {
          // Backprop for:
          //   p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
          scalar_t this_p_grad = p_grad_a[b][s][s_begin];
          p_a[b][s - 1][t_begin] += this_p_grad;
          px_a[b][s - 1][t_begin] += this_p_grad;
        }
        // There is no backprop for:
        //   p_a[b][s_begin][t_begin] = 0.0;
        // .. but we can use this for a check, that the grad at the beginning
        // of the sequence is equal to the grad at the end of the sequence.
        if (ans_grad_a[b] != 0.0) {
          float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
          if (grad_ratio - 1.0 > 0.01) {
            printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                   (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
          }
        }
      }
    }));
  return ans;
}
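For reference, here is what the inner (s, t) step of the backward loop above computes, written as an illustrative Python helper. The name `backprop_one_cell` is ours, and it uses plain `+=` where the C++ above assigns and asserts the target is still zero; it assumes indexable containers shaped like the tensors above.

import math

def backprop_one_cell(p, px, py, p_grad, px_grad, py_grad, b, s, t):
    # Backprop through p[b][s][t] = LogAdd(p[b][s-1][t] + px[b][s-1][t],
    #                                      p[b][s][t-1] + py[b][s][t-1]).
    # d LogAdd(a, c)/da = exp(a - LogAdd(a, c)), and the two partial
    # derivatives sum to 1, which is what term1_deriv / term2_deriv exploit.
    term1 = p[b][s - 1][t] + px[b][s - 1][t]
    total = p[b][s][t]
    term1_deriv = math.exp(term1 - total)
    term2_deriv = 1.0 - term1_deriv
    grad = p_grad[b][s][t]
    px_grad[b][s - 1][t] += term1_deriv * grad
    p_grad[b][s - 1][t]  += term1_deriv * grad
    py_grad[b][s][t - 1] += term2_deriv * grad
    p_grad[b][s][t - 1]  += term2_deriv * grad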
// backward of mutual_information.  Returns (input_grad, params_grad)
std::vector<torch::Tensor> mutual_information_backward_cpu(torch::Tensor input,
                                                           torch::Tensor params,
                                                           torch::Tensor output_grad) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
torch_mutual_information/mutual_information_cuda.cpp    (view file @ 77eed83f)

@@ -3,18 +3,27 @@
// forward of mutual_information.  """... """ comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor input,
                                      torch::Tensor params);
// It is the core recursion in the sequence-to-sequence mutual information
// computation.
// returns 'ans', of dimension B (batch size).
torch::Tensor mutual_information_cuda(torch::Tensor px,     // [B][S][T+1]
                                      torch::Tensor py,     // [B][S+1][T]
                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
                                      torch::Tensor p);     // [B][S+1][T+1]; an output

// backward of mutual_information; returns (grad_input, grad_params).
std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
                                                            torch::Tensor params,
                                                            torch::Tensor grad_output);
// backward of mutual_information; returns (grad_px, grad_py)
std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor px,
                                                            torch::Tensor py,
                                                            std::optional<torch::Tensor> boundary_info,
                                                            torch::Tensor p,
                                                            torch::Tensor ans_grad);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda", &mutual_information_cuda,
        "Learned nonlinearity forward function (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
        "Learned nonlinearity backward function (CUDA)");
  m.def("mutual_information_cuda", &mutual_information_cuda,
        "Mutual information forward function (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
        "Mutual information backward function (CUDA)");
}
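The Python wrapper in mutual_information.py calls `_mutual_information_forward_dispatcher`; that glue code is not part of this diff. A hedged sketch of how such a dispatcher could route between the CPU and CUDA bindings: the module names `torch_mutual_information_cpu` and `torch_mutual_information_cuda` are assumptions for illustration only, while the bound function names match the definitions in this commit.

import torch
import torch_mutual_information_cpu as _cpu    # assumed module name
import torch_mutual_information_cuda as _cuda  # assumed module name

def _mutual_information_forward_dispatcher(px: torch.Tensor,
                                           py: torch.Tensor,
                                           boundaries,
                                           p: torch.Tensor) -> torch.Tensor:
    # Route to the CUDA or CPU implementation based on where px lives.
    if px.is_cuda:
        return _cuda.mutual_information_cuda(px, py, boundaries, p)
    return _cpu.mutual_information_cpu(px, py, boundaries, p)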
torch_mutual_information/mutual_information_cuda_kernel.cu    (view file @ 77eed83f)

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>  // for getCurrentCUDAStream()
#include <cooperative_groups.h>
#include <cmath>  // for INFINITY

/*
  Tiled summing reduction within a warp.  Requires that the thread-block
  be 1-dimensional, i.e. blockDim.y == blockDim.z == 1.  Does not use
  __syncthreads, so it is safe to call in a subset of threads.
  TODO: we can in principle do this without a buffer, using __shfl_down()
  (see here https://sodocumentation.net/cuda/topic/6566/parallel-reduction--e-g--how-to-sum-an-array-)
  if CC >= 3.0.

// returns log(exp(x) + exp(y)).
__forceinline__ __device__ double LogAdd(double x, double y) {
  double diff;

  Args:
      threads_per_tile:  Must be a power of 2 in the interval [1,32].  Summation is
                         within blocks of threads of this size.
      buf:               Pointer to the start of a __shared__ buffer of size
                         blockDim.x, to be used as a temporary within this function.
      val:               The value to be summed
  Return:
      Threads where threadIdx.x % threads_per_tile == 0 will return the sum:
        \sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i]
      The return value in other threads is undefined.
 */
template <typename scalar_t>
__forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
                                                          __volatile__ scalar_t *buf,
                                                          scalar_t val) {
  // Each iteration halves the number of active threads
  // Each thread adds its partial sum[i] to sum[lane+i]
  for (int i = threads_per_tile / 2; i > 0; i /= 2) {
    buf[threadIdx.x] = val;
    if (threadIdx.x % threads_per_tile < i)
      val += buf[threadIdx.x + i];
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff - diff != 0)
    return x;   // x and y are probably -inf.  Return the larger one.
  else
    return x + log1p(exp(diff));
}

// returns log(exp(x) + exp(y)).
__forceinline__ __device__ inline float LogAdd(float x, float y) {
  float diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  return val;
  // Only threads with threadIdx.x % threads_per_tile == 0 will
  // return the full sums of their tiles.
  // diff is negative.  x is now the larger one.
  if (diff - diff != 0)
    return x;   // x and y are probably -inf.  Return the larger one.
  else
    return x + log1p(exp(diff));
}

/*
  Forward of mutual_information.  Each thread block handles blocks of (x, y) shape
  equal to (BLOCK_S_SIZE, BLOCK_T_SIZE), e.g. (4, 64).  Thread blocks loop over such
  blocks, but they might typically loop only once.  We sequentially launch groups of
  threads in such a way that thread-blocks within a group do not depend on each other.
  equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).  Thread blocks loop over such
  blocks, but they might loop only once if there is not that much data to process.
  We sequentially launch groups of threads in such a way that thread-blocks
  within a group do not depend on each other.

  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.

  Args:
      input:  input image, shape (B, C, T) where B is batch size, C is
      px:     log-odds ratio of generating next x in the sequence, i.e.
              xy[b][s][t] is the log-odds probability of generating x_t of
              the b'th image given subsequences of length (s, t).  (See
              mutual_information.py for more info).  Shape [B][S][T + 1]
      py:     log-odds ratio of generating next y in the sequence.
              Shape [B][S + 1][T]
      p:      matrix of mutual information of sub-sequences, that this
              function writes to.  Shape [B][S + 1][T + 1].  This function
              computes the following recursion:
                p[b,0,0] = 0.0
                p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                   p[b,s,t-1] + py[b,s,t-1])     (if s > 0 or t > 0)
      boundary:  If set, a tensor of shape [B][4] of type int64_t, which
              contains, for each batch element, [s_begin, t_begin, s_end, t_end]
              which are the beginning and end (one-past-the-last) of the
              x and y sequences that we should process.  If not set, these
              default to (0, 0, S, T), and they should not exceed these bounds
              or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
      input:  input image, shape (B, C, T) where B is batch size, C is
              the number of channels and T is the time axis.  (For more-than-1d
              convolution setups, T would really be more than 1 axis, reshaped).
      params: of shape (C, N+1) where N is the number of linear regions in the
@@ -75,68 +103,74 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
  This kernel is allocated with `extern_buf` containing enough memory
  to store 2*N + 3 values of type scalar_t.
  The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
  The requirements on the grid dimension are:
       gridDim.x == num-channels C (required)
       1 <= gridDim.y <= B, where B is the number of blocks
       gridDim.z == 1
  When we invoke this kernel, we'll invoke it as:
   mutual_information_kernel<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:
    bytesShared = sizeof(shared_t) * (2N + 3)
  We also require N + 1 <= THREADS_PER_BLOCK.

  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
  be at least 128.
*/
extern __shared__ int extern_buf[];

template <typename scalar_t,
          int BLOCK_S_SIZE,   // e.g. BLOCK_S_SIZE == 4; power of 2
          int BLOCK_T_SIZE>   // e.g. BLOCK_T_SIZE == 64; power of 2.
                              // BLOCK_T_SIZE * 4 must equal num_threads; and must be >= 128, so BLOCK_T_SIZE >= 32 is required.
                              // (Note: this 4 is unrelated to BLOCK_S_SIZE but can be viewed as 1<<2,
                              // where 2 is the loop unrolling factor).
          int BLOCK_SIZE>     // e.g. BLOCK_SIZE == 16 or 32.  Note: we require the
                              // num-threads be at least 128.
__global__
void mutual_information_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3> px,   // B, S, T, i.e. batch, x_seq_length, y_seq_length
    torch::PackedTensorAccessor32<scalar_t, 3> py,   // B, S, T, as above
    torch::PackedTensorAccessor32<scalar_t, 3> p,    // B, S, T, as above.  This is an output.
    torch::PackedTensorAccessor32<scalar_t, 2> boundary,  // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    torch::PackedTensorAccessor32<scalar_t, 3> px,   // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
    torch::PackedTensorAccessor32<scalar_t, 3> py,   // B, S + 1, T.
    torch::PackedTensorAccessor32<scalar_t, 3> p,    // B, S + 1, T + 1.  This is an output.
    torch::PackedTensorAccessor32<int64_t, 2> boundary,   // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    torch::PackedTensorAccessor32<scalar_t, 1> ans,  // [B]
    int iter) {
  // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
  //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
  // so that each group depends on the previous group...
  const int block_dimx = BLOCK_T_SIZE * 4;   // known at compile time.
  assert(blockDim.x == block_dimx);

  const int B = px.size(0), S = px.size(1), T = py.size(2);

  // num_s_blocks and num_t_blocks are the number of blocks we need to cover the
  // array of size (S, T) with blocks of this size, in the s and t directions
  // respectively.
  const int num_s_blocks = (S + BLOCK_S_SIZE - 1) / BLOCK_S_SIZE,
      num_t_blocks = (T + BLOCK_T_SIZE - 1) / BLOCK_T_SIZE;
  // num_blocks_this_iter is an upper bound on the number of blocks that might
  // be active on this iteration.  We go from the bottom left of the image
  // so that on iter == 0 we process only one block with block-index (0, 0)
  // then on iter == 1 we process block-indexes (1, 0) and (0, 1); and then on iter==2
  // we process (2, 0), (1, 1) and (0, 2); and so on.  We also will never have more
  // than `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
  // the numbering we use corresponds to s and not t, so if we hit the num_t_blocks limit,
  // the lowest-numbered blocks on s would just not be active and we'll 'continue' below).
  // You can read the following expressions as simplifications of, for example,
  //   num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
  // i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
  const int num_s_blocks = S / BLOCK_SIZE + 1,
      num_t_blocks = T / BLOCK_SIZE + 1;
  // num_blocks_this_iter is an upper bound on the number of blocks of size
  // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration.  We go
  // from the bottom left of the image so that on iter == 0 we process only one
  // block with block-index (0, 0) then on iter == 1 we process block-indexes
  // (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0,
  // 2); and so on.  We also will never have more than `num_s_blocks` blocks
  // (We'll never have more than num_t_blocks either, but the numbering we use
  // corresponds to s and not t, so if we hit the num_t_blocks limit, the
  // lowest-numbered blocks on s would just not be active and we'll 'continue'
  // below).
  int num_blocks_this_iter = min(iter + 1, num_s_blocks);

  __shared__ scalar_t px_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
      py_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
      p_buf[BLOCK_S_SIZE + 1][BLOCK_T_SIZE + 1];  // 1st row/col of p_buf
                                                  // correspond to the previous
                                                  // blocks, or an edge case.

  // For the block with s_block_begin == 0 and t_block_begin == 0 (for
  // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
  // for out-of-range indexes.
  // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
  __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
      py_buf[BLOCK_SIZE][BLOCK_SIZE];

  // 1st row/col of p_buf correspond to the previous blocks, or to an edge case.
  // So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t
  // - 1] - normalizer); or 0 for out-of-range values.
  __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

  // boundary_buf will be used to store the b'th row of `boundary` if we have
  // boundary information supplied.
  __shared__ int64_t boundary_buf[4];
  __shared__ boundary_buf[4];

  if (threadIdx.x == 0) {
    boundary_buf[0] = 0;
    boundary_buf[1] = 0;
    boundary_buf[2] = S;
    boundary_buf[3] = T;
  }

  // batch_block_iter iterates over both batch elements (index b), and block
  // indexes in the range [0..num_blocks_this_iter-1]
  for (int batch_block_iter = blockIdx.x;
       batch_block_iter < B * num_blocks_this_iter;
       batch_block_iter += gridDim.x) {
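The scheduling described in the comments above (anti-diagonals of blocks, so that all blocks launched in the same `iter` are mutually independent) can be visualized with a small, purely illustrative Python snippet; the helper name is ours.

def blocks_for_iter(it, num_s_blocks, num_t_blocks):
    # Blocks on the same anti-diagonal (s_block + t_block == it) do not depend
    # on each other; each depends only on blocks from earlier iterations.
    return [(sb, it - sb)
            for sb in range(min(it + 1, num_s_blocks))
            if it - sb < num_t_blocks]

for it in range(4):
    print(it, blocks_for_iter(it, num_s_blocks=3, num_t_blocks=3))
# 0 [(0, 0)]
# 1 [(0, 1), (1, 0)]
# 2 [(0, 2), (1, 1), (2, 0)]
# 3 [(1, 2), (2, 1)]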
@@ -149,146 +183,219 @@ void mutual_information_kernel(
    bool is_origin_block = (s_block_begin * t_block_begin == 0);

    int s_end, t_end;  // s_end and t_end are the end points (last-plus-one) of the entire sequence.
    if (boundary.size(0) == 0) {
      s_end = S;
      t_end = T;
    } else {
      if (threadDim.x < 4)
        boundary_buf[threadDim.x] = boundary[b][threadDim.x];
      __syncthreads();
      int s_begin = boundary_buf[0],
          t_begin = boundary_buf[1];
      s_end = boundary_buf[2];
      t_end = boundary_buf[3];
      s_block_begin += s_begin;
      t_block_begin += t_begin;
    }
    // block_S and block_T are the actual sizes of this block, up to
    // (BLOCK_S_SIZE, BLOCK_T_SIZE) but possibly truncated if we
    // are towards the end of the sequence.
    int block_S = min(BLOCK_T_SIZE, s_end - s_block_begin),
        block_T = min(BLOCK_S_SIZE, t_end - t_block_begin);

    if (threadDim.x < 4 && boundary.size(0) != 0)
      boundary_buf[threadDim.x] = boundary[b][threadDim.x];
    __syncthreads();
    int s_begin = boundary_buf[0],
        t_begin = boundary_buf[1];
    s_end = boundary_buf[2];
    t_end = boundary_buf[3];
    s_block_begin += s_begin;
    t_block_begin += t_begin;

    // block_S and block_T are the actual sizes of this block, no greater than
    // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
    // the end of the sequence.
    int block_S = min(BLOCK_SIZE, s_end - s_block_begin),
        block_T = min(BLOCK_SIZE, t_end - t_block_begin);

    if (block_S <= 0 || block_T <= 0)
      continue;
    // Load px_buf and py_buf.  We exponentiate; the assumption is that they
    // won't overflow or underflow!  If they overflow we'll detect it later!
    for (int i = threadDim.x; i < BLOCK_S_SIZE * BLOCK_T_SIZE; i += block_dimx) {
      int t = i % BLOCK_T_SIZE,
          s = i / BLOCK_T_SIZE;
      if (s < block_S && t < block_T) {
        px_buf[s][t] = exp(px[b][s + s_block_begin][t + t_block_begin]);
        py_buf[s][t] = exp(py[b][s + s_block_begin][t + t_block_begin]);
      } else {
        // Not necessary?  We'll see
        px_buf[s][t] = 0.0;
        py_buf[s][t] = 0.0;
      }

    // Load px_buf and py_buf.  We exponentiate; the assumption is that they most likely
    // won't overflow or underflow, but if they do overflow we'll detect it later; we'll
    // also detect certain kinds of underflow.
    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      int t_in_block = i % BLOCK_SIZE,
          s_in_block = i / BLOCK_SIZE,
          s = s_in_block + s_block_begin,
          t = t_in_block + t_block_begin;
      // the comparisons with S and T below just make sure we don't access
      // out-of-memory regions; they do not guarantee we are in the range given
      // by s_begin, s_end and so on.  Note: comparing as unsigned int makes sure
      // the index is nonnegative.
      scalar_t this_px = 0.0;
      if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) && t <= T)
        this_px = exp(px[b][s - 1][t]);
      px_buf[s_in_block][t_in_block] = this_px;
      scalar_t this_py = 0.0;
      if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) && s <= S)
        this_py = exp(py[b][s][t - 1]);
      py_buf[s_in_block][t_in_block] = this_py;
    }
    // Load the 1st row and column of p_buf (except element [0][0] is not needed).
    if (threadIdx.x < 64) {  // 64 == warp size...
      if (threadIdx.x <= BLOCK_S_SIZE) {
        // this s and t are offsets relative to the block start
        int s = threadIdx.x - 1,
            t = -1;
        if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
            static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
          p_buf[threadIdx.x][0] = p[s + s_block_begin][s + t_block_begin];
        else
          p_buf[threadIdx.x][0] = -infinity;

    // Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer.
    if (threadIdx.x < 64) {  // 64 == warp size.  First half of threads...
      if (threadIdx.x <= BLOCK_SIZE) {
        // s_in_p_buf are simply the indexes into p_buf
        int s_in_p_buf = threadIdx.x,
            t_in_p_buf = 0,
            s = s_in_p_buf + s_block_begin - 1,
            t = t_in_p_buf + t_block_begin - 1;
        // The if-statement below just guards against out-of-range memory
        // accesses, it does not guarantee that we really need these values.
        scalar_t this_p = -INFINITY;
        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
          this_p = p[s + s_block_begin][s + t_block_begin];
        p_buf[threadIdx.x][0] = this_p;
      }
    } else {
      if (threadIdx.x - 64 <= BLOCK_T_SIZE) {
        int i = threadIdx.x - 64,
            t = i - 1,
            s = -1;
        if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
            static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
          p_buf[0][i] = p[s + s_block_begin][s + t_block_begin];
        else {
          p_buf[0][i] = (is_origin_block && i == 1 ? 1.0 / -infinity;
        }
      }
    } else {
      // Another warp handles the other leg
      if (threadIdx.x - 64 <= BLOCK_SIZE) {
        int s_in_p_buf = 0,
            t_in_p_buf = threadIdx.x - 64,
            s = s_in_p_buf + s_block_begin - 1,
            t = t_in_p_buf + t_block_begin - 1;
        // The if-statement below just guards against out-of-range memory
        // accesses, it does not guarantee that we really need these values.
        scalar_t this_p = -INFINITY;
        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
          this_p = p[s + s_block_begin][s + t_block_begin];
        p_buf[threadIdx.x][0] = this_p;
      }
    }
    __syncthreads();

    // We read p_buf in log-space; subtract 'normalizer', which mathematically
    // could be any finite number, to get in a reasonable range of probabilities,
    // and then exponentiate.  We'll do everything in non-log space, and later
    // take a log before we write out the data.
    scalar_t normalizer = (is_origin_block ? 0.0 :
                           max(px_buf[0][1], px_buf[1][0]));

    // Normalize and exponentiate the edge elements of p_buf, i.e. the elements
    // where at one index is 0.  The [0][0] element is special; we write 0.0,
    // and we'll overwrite with 1.0 if there is a panic situation due to
    // overflow.
    if (threadIdx.x <= BLOCK_SIZE) {
      if (threadIdx.x == 0) {
        // p_buf[0][0] is never used for its normal purpose; we set it to zero.
        // We'll later write an infinity there if something goes wrong, as a
        // 'panic' indicator.
        p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
                                 exp(p_buf[threadIdx.x][0] - normalizer));
      }
    }
    else if ((int)threadIdx.x - 64 < BLOCK_SIZE) {
      p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
    }
  }

    if (threadIdx.x == 0) {
      // This if-statement is an optimization and modification of the loop below
      // for the value i == 0, i.e. inner-iteration == 0.  The modification
      // is to use 0.0 if this is the "origin block" with s_block_begin == 0 and
      // t_block_begin == 0.  This corresponds to the probability of the pair of
      // sequences of length (0, 0).
      p_buf[1][1] = (is_origin_block ? 0.0 :
                     p_buf[0][1] * px_buf[0][0] + p_buf[1][0] * py_buf[0][0]);
    }

      N = params.size(1) - 1,
      K = N / 2;  // Note: N and K are powers of 2, with K >= 1.

    scalar_t p_buf_s1_t;  // This is for an optimization.
    if (i < BLOCK_SIZE) {
      int s = threadIdx.x;
      p_buf_s1_t = p_buf[s + 1][0];
    }

  const int c = blockIdx.x;  // c is channel index

    for (int i = 1; i < 2 * BLOCK_SIZE; i++) {
      // i is the inner iteration, which corresponds to the (s + t) indexes of the
      // elements within the block that we write.  So i == 0 writes positions
      // (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
      // (0, 2), (1, 1) and (2, 1); and so on.
      // Note: not many threads participate in this part, only up to BLOCK_SIZE
      // at most.  Unfortunately we couldn't figure out a very meaningful way
      // for more threads to do work, that looked like it would really speed
      // things up.
      // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
      // but we do at least do the I/O in an efficient way and keep the
      // inner loop simple and fast (e.g. no exp() or log()).
      int s = threadIdx.x,
          t = i - s;
      if (t >= 0) {
        // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
        // row and column for context from previous blocks.  Taking into account
        // the way these buffers relate to the tensors p, px and py, and
        // ignoring `normalizer`, code below can be interpreted as follows,
        // writing sbb for s_block_begin and tbb for t_block_begin:
        //
        //  p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
        //                              p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
        //
        // where you can see that apart from the offsets of tbb and sbb, this is
        // the same as the recursion defined for p in
        // mutual_information.py:mutual_information_recursion().
#if 1
        p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
                              p_buf[s + 1][t] * py_buf[s][t];
#else
        // This is an optimization of the statement above (the other half of
        // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
        // the need for a load from shared memory.
        p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
        p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif
      }
      __syncthreads();
    }
  scalar_t *y_vals = (scalar_t*)extern_buf,      // [N], actually there are 3
                                                 // spaces between here and
                                                 // `params_buf` for storing scale
                                                 // and inv_scale and l == params[c][0].
      *params_buf = (scalar_t*)y_vals + 3 + N;   // [N].  params_buf[n] contains params[c][n-1].
                                                 // params_buf[-1] contains params[c][0] == log of scale;
                                                 // params_buf[-2] contains scale, params_buf[-3]
                                                 // contains inv_scale.
  // Load parameters
  if (threadIdx.x <= N)
    params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
  __syncthreads();

    // Write out the data.
    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      int t = i % BLOCK_SIZE,
          s = i / BLOCK_SIZE;
      if (s < block_S && t < block_T) {
        float this_p = p_buf[s + 1][t + 1];
        p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p);
        if (this_p - this_p != 0 || this_p == 0)
          p_buf[0][0] = 1.0;  // This is a "panic" flag.
      }
    }

  if (threadIdx.x == 0) {
    scalar_t scale = exp(params_buf[-1]);
    params_buf[-2] = scale;
    params_buf[-3] = 1.0 / scale;
  }
  __syncthreads();
    __syncthreads();

  if (threadIdx.x == 0) {
    scalar_t scale = params_buf[-2],
        sum_positive = 0.0;
    for (int i = 0; i < K; i++) {
      // params_buf is indexed with an index one less than params.
      scalar_t pos_scaled_param = params_buf[K + i] * scale;
      y_vals[K + i] = sum_positive - pos_scaled_param * i;
      sum_positive += pos_scaled_param;
    }
  } else if (threadIdx.x == 64) {
    scalar_t scale = params_buf[-2],
        sum_negative = 0.0;
    for (int i = 0; i < K; i++) {
      scalar_t neg_scaled_param = params_buf[K - 1 - i] * scale;
      sum_negative -= neg_scaled_param;
      y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);

    if (threadIdx.x == 0) {
      // Write `ans`, if this is the final (top-right) block in its sequence
      // Logically, the following equation corresponds to:
      //   ans[b] = p[b][s_end][t_end]
      if (s_block_begin + S > s_end && t_block_begin + T > t_end)
        ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]);
    }
  }
    __syncthreads();

  scalar_t inv_scale = params_buf[-3];
  int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
      b_offset = threadIdx.x / T_inc,   // offset within batch
      t_start = threadIdx.x % T_inc;
  for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
       b += gridDim.y * images_per_thread_block) {
    // We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
    // images_per_thread_block) as a small optimization because the only case we
    // really need to loop is when images_per_thread_block == 1: we only let
    // images_per_thread_block > 1 if T * images_per_thread_block <=
    // THREADS_PER_BLOCK.
    for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
      scalar_t this_input = input[b][c][t],
          x = this_input * inv_scale + K;
      if (x < 0) x = 0;
      else if (x >= N) x = N - 1;
      // C++ rounds toward zero.
      int n = (int)x;
      // OK, at this point, 0 <= min < N.  Versus the CPU code, we removed the
      // factor of 'scale' because params_buf already has that factor.
      output[b][c][t] = this_input * params_buf[n] + y_vals[n];

    if (p_buf[0][0] != 0.0) {
      // "panic" flag set.  We need to re-do the computation using log-add.
      // This time we won't use the buffers, we'll just load and save from main
      // memory.  This code should very rarely be reached; and anyway, caching
      // should help us quite a bit.
      for (int i = 0; i < 2 * BLOCK_SIZE; i++) {
        int block_s = threadIdx.x,
            block_t = i - block_s;
        if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T) &&
            block_s < block_S) {
          int s = block_s + s_block_begin,
              t = block_t + t_block_begin;
          float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
              p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
              this_px = px[b][s][t],
              this_py = py[b][s][t];
          float this_p = LogAdd(p_s1 + this_px, p_t1 + this_py);
          if (i == 0 && is_origin_block)
            this_p = 0.0;
          p[b][s][t] = this_p;
        }
      }
      if (threadIdx.x == 0) {
        // Write `ans`, if this is the final (top-right) block in its sequence.
        // This is only reached in the 'panic situation' where we had overflow.
        if (s_block_begin + S > s_end && t_block_begin + T > t_end)
          ans[b] = p[b][s_end][t_end];
      }
    }
  }
}
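The multiply-add in the inner loop of the forward kernel above is the log-space recursion rewritten in exp space: because every term shares the same `normalizer`, it cancels, so exp(p - normalizer) satisfies a plain linear recurrence in px_buf = exp(px) and py_buf = exp(py). A small, purely illustrative Python check of that identity for one cell:

import math

def logadd(a, b):
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# Log-space recursion for one cell:
p_s1_t, p_s_t1, px_v, py_v = -3.0, -2.5, -0.7, -1.2
p_log = logadd(p_s1_t + px_v, p_s_t1 + py_v)

# Exp-space version with an arbitrary shared normalizer, as in p_buf/px_buf/py_buf:
normalizer = -2.0
ep = math.exp(p_s1_t - normalizer) * math.exp(px_v) \
   + math.exp(p_s_t1 - normalizer) * math.exp(py_v)
p_from_exp = normalizer + math.log(ep)

assert abs(p_log - p_from_exp) < 1e-12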
@@ -334,74 +441,135 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
}

/*
  Backward of mutual_information.  Each thread group handles a single channel (channel
  c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
  image within the batch).

  Backward of mutual_information.

  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.

  If we were to write the forward pass in non-log space, it would be (ignoring
  edge cases), as follows... we'll prefix all the variable names with e, e.g. ep,
  to clarify that it's the exp of the actual argument p:

  Args:
      input:  input image, shape (B, C, T) where B is batch size, C is
              the number of channels and T is the time axis.  (For more-than-1d
              convolution setups, T would really be more than 1 axis, reshaped).
      params: of shape (C, N+1) where N is the number of linear regions in the
              piecewise linear function; params[c][0] is l which is
              a log scale parameter that dictates how far apart
              the discontinuities in the piecewise linear function are,
              and params[c][n+1] for 0 <= n < N are the derivatives
              of the linear parts of the piecewise linear function.
              The discontinuities of the function are at:
                   exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
      output:  The transformed input, shape (B, C, T)
      images_per_thread_block:  The number of images processed by each thread
              block.  The calling code must guarantee that this is a power
              of 2, and that EITHER:
                 (THREADS_PER_BLOCK / images_per_thread_block >= T AND
                  THREADS_PER_BLOCK / images_per_thread_block >= N),
              OR
                 images_per_thread_block == 1
              .. this is used for a small optimization.

       ep[b][s][t] = ep[b][s - 1][t] * epx[b][s - 1][t] +
                     ep[b][s][t - 1] * epy[b][s][t - 1].         (eq. 1)

  ALSO,
  (A)
  First we consider the part that involves recursion, i.e. the part involving only gradients of
  ep.  The backprop involving ep only would be:

       ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
       ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].

  .. and if we add 1 to the s index of the first equation above and 1 to the
  t index of the second equation, we can see that:

  This kernel is allocated with `extern_buf` containing enough memory
  to store 2*N + 3 values of type scalar_t.

       ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
                          ep_grad[b][s][t + 1] * epy[b][s][t].

  Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
  == dy/dp / ep == p_grad / ep.
  I.e. ep_grad = p_grad / ep.  So we can write the above as:

   p_grad[b][s][t] / ep[b][s][t] = p_grad[b][s + 1][t] / ep[b][s + 1][t] * epx[b][s][t] +
                                   p_grad[b][s][t + 1] / ep[b][s][t + 1] * epy[b][s][t].

  Or, rearranging:

   p_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) +
                     p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]).   (eq. 2)

  (B)  The following is the backprop for epx and epy from (eq. 1):

        epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
        epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]

  .. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd:

        epx_grad[b][s][t] = ep_grad[b][s + 1][t] * ep[b][s][t]
        epy_grad[b][s][t] = ep_grad[b][s][t + 1] * ep[b][s][t]

  The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
  The requirements on the grid dimension are:
       gridDim.x == num-channels C (required)
       1 <= gridDim.y <= B, where B is the number of blocks
       gridDim.z == 1
  When we invoke this kernel, we'll invoke it as:
   mutual_information_backward_kernel<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:
     bytesShared = sizeof(shared_t) * (2N + 3)
  We also require that N <= THREADS_PER_BLOCK (for best performance,
  N should be quite small, like no larger than 8 or so).
  We also require 4 <= N <= 16 for this code!
  And we require that
      N <= (THREADS_PER_BLOCK / images_per_thread_block)
  (both sides will be powers of 2).. this ensures that blocks of threads
  summing the N values are always within the same image, which helps
  avoid a problem where some loops over 'b' would be done earlier
  than others, and we'd end up counting certain pixels twice as their
  output_grad would stay nonzero.

  Using, similar to the above, ep_grad = p_grad / ep, and similarly,
  epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for ep and so on,
  the above becomes:

   px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
   py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])

  Rearranging:

   px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])   (eq. 3a)
   py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])   (eq. 3b)

  Defining terms that are common to (eq. 2) and (eqs. 3a,3b), write:

     xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])      (eq. 4)
     yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])      (eq. 5)

  .. and note that these quantities are <= 1 so there is no problem doing
  the exponentiation.  So the recursion can be simplified as:

   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
                     p_grad[b][s][t + 1] * yderiv[b][s][t]                  (eq. 6)
   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]                 (eq. 7)
   py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]                 (eq. 8)

  (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
  not clear to me that this is the best strategy since that would require an extra
  write to shared memory within the loop that's the limiting factor.)

  The backward pass will be slightly different from the forward pass in terms of
  how we store p (and p_grad), because for writing a particular block of p_grad, we
  need context on the top and right instead of the bottom and left.
 */
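The closed-form gradients in (eq. 4)-(eq. 8) can be checked numerically against autograd on a tiny example. The following Python sketch is purely illustrative (sizes and variable names are ad hoc, not from this code); it builds p with torch.logsumexp and compares the recursions to px.grad and py.grad:

import torch

# Tiny check of (eq. 4)-(eq. 8) on a 2x2 problem.
S, T = 2, 2
px = torch.randn(S, T + 1, dtype=torch.double, requires_grad=True)
py = torch.randn(S + 1, T, dtype=torch.double, requires_grad=True)

p = [[None] * (T + 1) for _ in range(S + 1)]
p[0][0] = torch.zeros((), dtype=torch.double)
for s in range(S + 1):
    for t in range(T + 1):
        if s == 0 and t == 0:
            continue
        terms = []
        if s > 0:
            terms.append(p[s - 1][t] + px[s - 1, t])
        if t > 0:
            terms.append(p[s][t - 1] + py[s, t - 1])
        p[s][t] = torch.logsumexp(torch.stack(terms), dim=0)
p[S][T].backward()   # ans = p[S][T]

# Closed-form gradients via (eq. 4)-(eq. 8), sweeping from top-right to bottom-left.
with torch.no_grad():
    pg = torch.zeros(S + 1, T + 1, dtype=torch.double)   # p_grad
    pxg = torch.zeros(S, T + 1, dtype=torch.double)      # px_grad
    pyg = torch.zeros(S + 1, T, dtype=torch.double)      # py_grad
    pg[S][T] = 1.0
    for s in range(S, -1, -1):
        for t in range(T, -1, -1):
            if s < S:
                xderiv = torch.exp(p[s][t] + px[s, t] - p[s + 1][t])   # (eq. 4)
                pg[s][t] += pg[s + 1][t] * xderiv                      # (eq. 6), x term
                pxg[s][t] = pg[s + 1][t] * xderiv                      # (eq. 7)
            if t < T:
                yderiv = torch.exp(p[s][t] + py[s, t] - p[s][t + 1])   # (eq. 5)
                pg[s][t] += pg[s][t + 1] * yderiv                      # (eq. 6), y term
                pyg[s][t] = pg[s][t + 1] * yderiv                      # (eq. 8)

assert torch.allclose(pxg, px.grad) and torch.allclose(pyg, py.grad)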
template <typename scalar_t>
__global__
void mutual_information_backward_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3> input,       // B, C, T, i.e. batch, channels, time
    torch::PackedTensorAccessor32<scalar_t, 2> params,      // C, N + 1
    torch::PackedTensorAccessor32<scalar_t, 3> output_grad, // B, C, T
    torch::PackedTensorAccessor32<scalar_t, 3> input_grad,  // B, C, T
    // params_grad is of dim (gridDim.y, C, N + 1), we'll sum over dim 0.
    torch::PackedTensorAccessor32<scalar_t, 3> params_grad,
    int images_per_thread_block) {
    torch::PackedTensorAccessor32<scalar_t, 3> px,        // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
    torch::PackedTensorAccessor32<scalar_t, 3> py,        // B, S + 1, T.
    torch::PackedTensorAccessor32<scalar_t, 3> p,         // B, S + 1, T + 1.  Produced in forward pass.
    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,  // [B].  This is an input.
    torch::PackedTensorAccessor32<scalar_t, 3> p_grad,    // B, S + 1, T + 1.  This is a temporary.
    torch::PackedTensorAccessor32<scalar_t, 3> px_grad,   // B, S, T + 1.
    torch::PackedTensorAccessor32<scalar_t, 3> py_grad,   // B, S + 1, T.
    torch::PackedTensorAccessor32<int64_t, 2> boundary,   // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    int iter) {
  // This kernel is sequentially called with 'iter' = num_iters - 1, num_iters - 2, .. 0,
  // where num_iters can be taken to be any sufficiently large number but will actually be:
  //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
  const int B = px.size(0), S = px.size(1), T = py.size(2);

  // For statements that are the same as the forward pass, we are omitting some comments
  // that we made there.  We'll focus, in the comments, on differences from the forward pass.
  const int num_s_blocks = S / BLOCK_SIZE + 1,
      num_t_blocks = T / BLOCK_SIZE + 1,
      num_blocks_this_iter = min(iter + 1, num_s_blocks);

  // px_buf and py_buf are used temporarily to store the px and py values,
  // but then modified to store the "xderiv" and "yderiv" values defined
  // in (eq. 4) and (eq. 5) above.  For out-of-range values, we'll write 0.0
  // here.
  __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
      py_buf[BLOCK_SIZE][BLOCK_SIZE];

  // p_buf is initially used to store p, and then (after we are done putting
  // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
  // p_grad.
  //
  // Unlike in the forward pass, p_buf has the same numbering as px_buf and
  // py_buf, not offset by 1: e.g., for the origin block, p_buf[0][0] refers
  // to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
  // the block for px_buf and py_buf; unlike in the forward pass, we store
  // context on the top right, not the bottom left, i.e. the elements at
  // (one past the largest indexes in the block).
  //
  // For out-of-range elements of p_buf, we'll put zero.
  __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

  // boundary_buf will be used to store the b'th row of `boundary` if we have
  // boundary information supplied.
  __shared__ int64_t boundary_buf[4];

  boundary_buf[0] = 0;
  boundary_buf[1] = 0;
  boundary_buf[2] = S;
  boundary_buf[3] = T;

  const int B = input.size(0), C = input.size(1),
@@ -669,38 +837,52 @@ void mutual_information_backward_kernel(
// forward of mutual_information.  See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cuda(torch::Tensor px,
                                      torch::Tensor py,
                                      std::optional<torch::Tensor> optional_boundary,
                                      torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && p.device().is_cuda(),
              "inputs must be CUDA tensors");
torch::Tensor mutual_information_cuda(torch::Tensor input,
                                      torch::Tensor params) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1;

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());

  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  torch::Tensor ans = torch::empty({B}, opts);

  int num_threads = 128,
      num_blocks = 128;
  const int num_s_blocks = S / BLOCK_SIZE + 1,
      num_t_blocks = T / BLOCK_SIZE + 1,
      num_iters = std::max<int>(num_s_blocks, num_t_blocks);

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({0, 0}, long_opts);

  for (int iter = 0; iter < num_iters; iter++) {
    mutual_information_kernel<scalar_t, 32><<<num_blocks, num_threads>>>(
        px.packed_accessor32<scalar_t, 3>(),
        py.packed_accessor32<scalar_t, 3>(),
        p.packed_accessor32<scalar_t, 3>(),
        optional_boundary.value().packed_accessor32<int64_t, 2>(),
        ans.packed_accessor32<scalar_t, 1>(),
        iter);
  }

  torch::Tensor output = torch::empty({B, C, T}, opts);
  if (C * B * T == 0)
    return output;

  int images_per_thread_block = 1;
  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK)
    images_per_thread_block *= 2;
  int grid_dim_y = 1;
  // If the number of channels is quite small (<128) we can launch more thread