OpenDAS / FAST-RNNT · Commits · 77eed83f

Commit 77eed83f, authored Jul 28, 2021 by Daniel Povey
    Some progress, still drafting.
parent e95d7864

Showing 4 changed files with 678 additions and 364 deletions (+678 −364)
Changed files:
  torch_mutual_information/mutual_information.py              (+29 −21)
  torch_mutual_information/mutual_information_cpu.cpp         (+173 −58)
  torch_mutual_information/mutual_information_cuda.cpp        (+17 −8)
  torch_mutual_information/mutual_information_cuda_kernel.cu  (+459 −277)
torch_mutual_information/mutual_information.py  (view file @ 77eed83f)

@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
     def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
         (B, S, T) = px.shape
+        # p is a tensor of shape (B, S + 1, T + 1), where p[s][t] is related to
+        # the mutual information of the pair of subsequences of x and y that are of
+        # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
+        # the mutual information of the entire pair of sequences, i.e. of lengths
+        # S and T respectively.
         # q is a rearrangement of a tensor p which is of shape (B,S,T),
         # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
         # representation is that each row of q depends only on the previous row,
@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
         # q[b, s-s_begin + t-t_begin, t-t_begin];
         # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
-        if px.requires_grad or py.requires_grad:
-            q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
-        else:
-            # We don't need to store q if we are not going to do backprop, but we
-            # do pass in a temporary with one real row, expanded to have "fake" rows,
-            # which happens to be convenient for the CPU implementation.
-            q = torch.empty({1, 1, T}, device=px.device, dtype=px.dtype).expand(B, S + T, T)
-        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
+        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
         if px.requires_grad or py.requires_grad:
             ctx.save_for_backward(px, py, boundaries, q)
@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
         make use of the formula computed by this function.
     Args:
-      px: A torch.Tensor of some floating point type, with shape [B][S][T],
+      px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
          where B is the batch size, S is the length of the 'x' sequence
          (including representations of EOS symbols but not BOS symbols), and T is the
          length of the 'y' sequence (including representations of
@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):
         code assumes for optimization purposes that the T axis has
         stride 1.
-      py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
+      py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
         representing
            py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
         This function does not treat x and y differently; the only difference
-        is that the implementation assumes for optimization purposes that y
-        is likely to be the shorter sequence, i.e. that "most of the time T < S",
-        and it will be faster if you respect this.
+        is that for optimization purposes we assume the last axis (the t axis)
+        has stride of 1; this is true if px and py are contiguous.
      boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
         [s_begin, t_begin, s_end, t_end].  If not supplied, the values
         [0, 0, S, T] will be assumed.  These are the beginning and
@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):
     Returns:
-        Returns a torch.Tensor of shape [B], containing the log of the mutual
-        information between the b'th pair of sequences.  This is defined by
-        the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
-        representing a mutual information between sub-sequences of lengths s and t:
-             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
-        where in the case where boundaries==None: the edge cases are handled
-        by treating p[b,-1,-1] as 0 and all other quantities with negative
-        indexes as -infinity; and ans[b] would equal p[S-1,T-1].  The extension to
-        cases where the boundaries are specified should be obvious.
+        Returns a torch.Tensor of shape [B], containing the log of the mutual
+        information between the b'th pair of sequences.  This is defined by
+        the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
+        representing a mutual information between sub-sequences of lengths s and t:
+             p[b,0,0] = 0.0
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])      (if s > 0 or t > 0)
+        where we handle edge cases by treating quantities with negative indexes
+        as -infinity.  The extension to cases where the boundaries are specified
+        should be obvious; it just works on shorter sequences with offsets into
+        px and py.
     """
-    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
+    assert px.ndim == 3
+    B, S, T1 = px.shape
+    T = T1 - 1
+    assert py.shape == (B, S + 1, T)
+    assert px.dtype == py.dtype
     (B, S, T) = px.shape
     if boundaries is not None:
         assert boundaries.dtype == torch.LongTensor
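Since the docstring above pins the recursion down completely, a slow reference implementation can be written directly from it.  The sketch below is not part of this commit; the helper name and the use of torch.logaddexp are illustrative assumptions, and it assumes the default boundaries [0, 0, S, T].

```python
import torch

def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: [B][S][T+1], py: [B][S+1][T]; returns ans: [B] with ans[b] = p[b][S][T].
    B, S, T1 = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = torch.full((B,), float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]

# Usage, e.g.:
#   B, S, T = 2, 3, 4
#   px = torch.randn(B, S, T + 1); py = torch.randn(B, S + 1, T)
#   ans = naive_mutual_information(px, py)
```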
torch_mutual_information/mutual_information_cpu.cpp  (view file @ 77eed83f)

@@ -3,6 +3,13 @@
+inline double Exp(double x) { return exp(x); }
+inline double Exp(float x) { return expf(x); }
+
 // returns log(exp(x) + exp(y)).
 inline double LogAdd(double x, double y) {
   double diff;
@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {
     diff = y - x;
   }
   // diff is negative.  x is now the larger one.
-  if (diff >= -1000) {
+  if (diff >= kMinLogDiffDouble) {
     double res;
     res = x + log1p(exp(diff));
     return res;
@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {
     diff = y - x;
   }
   // diff is negative.  x is now the larger one.
-  if (diff >= -200) {
+  if (diff >= kMinLogDiffFloat) {
     float res;
     res = x + log1pf(expf(diff));
     return res;
   }
   return x;  // return the larger one.
 }
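For illustration only, the same guarded log-add can be written in a few lines of Python.  This is a sketch of what LogAdd computes, not code from the commit; the cutoff argument stands in for kMinLogDiffFloat / kMinLogDiffDouble, whose definitions are outside this hunk.

```python
import math

def log_add(x: float, y: float, min_log_diff: float) -> float:
    # Returns log(exp(x) + exp(y)) without overflow.  min_log_diff plays the
    # role of kMinLogDiffFloat: below it, exp(diff) underflows to zero and we
    # simply return the larger argument.  A typical choice would be something
    # like log of the floating-point epsilon (an assumption, not the commit's value).
    if x < y:
        x, y = y, x
    diff = y - x          # diff <= 0; x is now the larger one.
    if diff >= min_log_diff:
        return x + math.log1p(math.exp(diff))
    return x
```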
 // forward of mutual_information.  See """... """ comment of `mutual_information` in
 // mutual_information.py for documentation of the behavior of this function.
 torch::Tensor mutual_information_cpu(torch::Tensor px,
                                      torch::Tensor py,
                                      std::optional<torch::Tensor> optional_boundary,
-                                     torch::Tensor q) {
+                                     torch::Tensor p) {
   TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
-  TORCH_CHECK(py.dim() == 3, "params must be 3-dimensional.");
-  TORCH_CHECK(q.dim() == 3, "params must be 3-dimensional.");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
+              "inputs must be CPU tensors");

   auto scalar_t = px.scalar_type();
   auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

   const int B = px.size(0),
             S = px.size(1),
-            T = px.size(2);
-  TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
+            T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

+  torch::Tensor ans = torch::empty({B}, opts);

-  auto long_opts = torch::TensorOptiona().dtype(torch::kInt64);
+  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

   bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
-    optional_boundary = torch::empty({}, long_opts);
+    optional_boundary = torch::empty({0, 0}, long_opts);
+  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
+              optional_boundary.value().dtype == torch::kInt64);

   AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
-      auto px_a = px.accessor<scalar_t, 3>(),
-           py_a = py.accessor<scalar_t, 3>();
-      for (int c = 0; c < C; c++) {
-        scalar_t sum_negative = 0.0,
-                 sum_positive = 0.0,
-                 scale = exp(params_a[c][0]);
-        for (int i = 0; i < K; i++) {
-          scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-                   neg_scaled_param = params_a[c][K - i] * scale;
-          y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          sum_positive += pos_scaled_param;
-          sum_negative -= neg_scaled_param;
-          y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-        }
-      }
+      auto px_a = px.packed_accessor32<scalar_t, 3>(),
+           py_a = py.packed_accessor32<scalar_t, 3>(),
+           p_a = p.packed_accessor32<scalar_t, 3>();
+      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
+      auto ans_a = ans.packed_accessor32<scalar_t, 1>();
+      for (int b = 0; b < B; b++) {
+        int s_begin, s_end, t_begin, t_end;
+        if (has_boundary) {
+          s_begin = boundary_a[b][0];
+          t_begin = boundary_a[b][1];
+          s_end = boundary_a[b][2];
+          t_end = boundary_a[b][3];
+        } else {
+          s_begin = 0;  s_end = S;
+          t_begin = 0;  t_end = T;
+        }
+        p_a[b][s_begin][t_begin] = 0.0;
+        for (int s = s_begin + 1; s <= s_end; ++s)
+          p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+        for (int t = t_begin + 1; t <= t_end; ++t)
+          p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+        for (int s = s_begin + 1; s <= s_end; ++s) {
+          scalar_t p_s_t1 = p_a[b][s][t_begin];
+          for (int t = t_begin + 1; t <= t_end; ++t) {
+            // The following statement is a small optimization of:
+            //    p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+            //                          p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+            // .. which obtains p_a[b][s][t - 1] from a register.
+            p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                                           p_s_t1 + py_a[b][s][t - 1]);
+          }
+        }
+        ans_a[b] = p_a[b][s_end][t_end];
+      }
+    }));
+  return ans;
+}
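A Python sketch of the loop above may help when reading the C++: it follows the same structure (fill the first row and column of p directly, then run the log-add recursion over the boundary rectangle).  The function name and the use of torch.logaddexp are illustrative assumptions, not part of the commit.

```python
import torch

def cpu_style_forward(px, py, boundary=None):
    # px: [B][S][T+1], py: [B][S+1][T];
    # boundary: optional [B][4] LongTensor of [s_begin, t_begin, s_end, t_end].
    B, S, T1 = px.shape
    T = T1 - 1
    p = torch.empty(B, S + 1, T + 1, dtype=px.dtype)
    ans = torch.empty(B, dtype=px.dtype)
    for b in range(B):
        s_begin, t_begin, s_end, t_end = (boundary[b].tolist() if boundary is not None
                                          else (0, 0, S, T))
        p[b, s_begin, t_begin] = 0.0
        for s in range(s_begin + 1, s_end + 1):        # first column
            p[b, s, t_begin] = p[b, s - 1, t_begin] + px[b, s - 1, t_begin]
        for t in range(t_begin + 1, t_end + 1):        # first row
            p[b, s_begin, t] = p[b, s_begin, t - 1] + py[b, s_begin, t - 1]
        for s in range(s_begin + 1, s_end + 1):        # interior cells
            for t in range(t_begin + 1, t_end + 1):
                p[b, s, t] = torch.logaddexp(p[b, s - 1, t] + px[b, s - 1, t],
                                             p[b, s, t - 1] + py[b, s, t - 1])
        ans[b] = p[b, s_end, t_end]
    return ans, p
```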
-      auto input_a = input.accessor<scalar_t, 3>(),
-           output_a = output.accessor<scalar_t, 3>();
-      for (int b = 0; b < B; b++) {
-        for (int c = 0; c < C; c++) {
-          scalar_t scale = exp(params_a[c][0]),
-                   inv_scale = 1.0 / scale;
-          for (int t = 0; t < T; t++) {
-            // `x` is the scaled input x plus an offset so that -K maps to 0.
-            // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
-            // so in a sense -K and +K are not special, but we include those
-            // extra values as an easy way to handle the semi-infinite regions
-            // that are < -(K-1) and > (K-1)
-            scalar_t input = input_a[b][c][t],
-                     x = input * inv_scale + K;
-            if (x < 0) x = 0;
-            else if (x >= N) x = N - 1;
-            // C++ rounds toward zero.
-            int n = (int) x;
-            // OK, at this point, 0 <= n < 2*K.
-            output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-          }
-        }
-      }
-    }));
-  return output;
-}

+// backward of mutual_information.  Returns (px_grad, py_grad).
+// p corresponds to what we computed in the forward pass.
+std::vector<torch::Tensor> mutual_information_backward_cpu(
+    torch::Tensor px,
+    torch::Tensor py,
+    std::optional<torch::Tensor> optional_boundary,
+    torch::Tensor p,
+    torch::Tensor ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
+  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
+              && ans_grad.device().is_cpu(),
+              "inputs must be CPU tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0),
+            S = px.size(1),
+            T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+
+  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
+  bool has_boundary = (bool)optional_boundary;
+  if (!has_boundary)
+    optional_boundary = torch::empty({0, 0}, long_opts);
+  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
+              optional_boundary.value().dtype == torch::kInt64);
+
+  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
+      auto px_a = px.packed_accessor32<scalar_t, 3>(),
+           py_a = py.packed_accessor32<scalar_t, 3>(),
+           p_a = p.packed_accessor32<scalar_t, 3>(),
+           p_grad_a = p.packed_accessor32<scalar_t, 3>();
+      auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
+      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
+      for (int b = 0; b < B; b++) {
+        int s_begin, s_end, t_begin, t_end;
+        if (has_boundary) {
+          s_begin = boundary_a[b][0];
+          t_begin = boundary_a[b][1];
+          s_end = boundary_a[b][2];
+          t_end = boundary_a[b][3];
+        } else {
+          s_begin = 0;  s_end = S;
+          t_begin = 0;  t_end = T;
+        }
+        // Backprop for:  ans_a[b] = p_a[b][s_end][t_end];
+        p_grad_a[b][s_end][t_end] = ans_grad_a[b];
+        for (int s = s_end; s > s_begin; --s) {
+          for (int t = t_end; t > t_begin; --t) {
+            // The statement we are backpropagating here is:
+            //     p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+            //                           p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+            scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                     total = p_a[b][s][t],
+                     term1_deriv = exp(term1 - total),
+                     term2_deriv = 1.0 - term1_deriv,
+                     grad = p_grad_a[b][s][t],
+                     term1_grad = term1_deriv * grad,
+                     term2_grad = term2_deriv * grad;
+            // We can assign to px_grad_a here rather than add, because we
+            // know it's currently zero.
+            TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
+            px_grad_a[b][s - 1][t] = term1_grad;
+            TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);  // likewise..
+            p_grad_a[b][s - 1][t] = term1_grad;
+            py_grad_a[b][s][t - 1] += term2_grad;
+            p_grad_a[b][s][t - 1] += term2_grad;
+          }
+        }
+        for (int t = t_end; t >= t_begin; --t) {
+          // Backprop for:
+          //   p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+          scalar_t this_p_grad = p_grad_a[b][s_begin][t];
+          p_grad_a[b][s_begin][t - 1] += this_p_grad;
+          py_grad_a[b][s_begin][t - 1] += this_p_grad;
+        }
+        for (int s = s_end; s >= s_begin; --s) {
+          // Backprop for:
+          //   p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+          scalar_t this_p_grad = p_grad_a[b][s][s_begin];
+          p_a[b][s - 1][t_begin] += this_p_grad;
+          px_a[b][s - 1][t_begin] += this_p_grad;
+        }
+        // There is no backprop for:
+        //   p_a[b][s_begin][t_begin] = 0.0;
+        // .. but we can use this for a check, that the grad at the beginning
+        // of the sequence is equal to the grad at the end of the sequence.
+        if (ans_grad_a[b] != 0.0) {
+          float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
+          if (grad_ratio - 1.0 > 0.01) {
+            printf("Warning: mutual_information backprop: expected these numbers "
+                   "to be the same: %f vs. %f\n",
+                   (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
+          }
+        }
+      }
+    }));
+  return ans;
+}

-// backward of mutual_information.  Returns (input_grad, params_grad)
-std::vector<torch::Tensor> mutual_information_backward_cpu(
-    torch::Tensor input,
-    torch::Tensor params,
-    torch::Tensor output_grad) {
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
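The backward loop above relies on the derivative of the log-add: d/da LogAdd(a, b) = exp(a − LogAdd(a, b)), and the two branch derivatives sum to one, hence term2_deriv = 1 − term1_deriv.  A hedged Python sketch of the interior-cell part (function name assumed, first-row/column terms omitted for brevity):

```python
import torch

def cpu_style_backward(px, py, p, ans_grad, boundary=None):
    # px: [B][S][T+1], py: [B][S+1][T], p: [B][S+1][T+1] from the forward pass,
    # ans_grad: [B].  Interior cells only; edge terms follow the same pattern.
    B, S, T1 = px.shape
    T = T1 - 1
    px_grad = torch.zeros_like(px)
    py_grad = torch.zeros_like(py)
    p_grad = torch.zeros(B, S + 1, T + 1, dtype=px.dtype)
    for b in range(B):
        s_begin, t_begin, s_end, t_end = (boundary[b].tolist() if boundary is not None
                                          else (0, 0, S, T))
        p_grad[b, s_end, t_end] = ans_grad[b]
        for s in range(s_end, s_begin, -1):
            for t in range(t_end, t_begin, -1):
                term1 = p[b, s - 1, t] + px[b, s - 1, t]
                term1_deriv = torch.exp(term1 - p[b, s, t])   # in [0, 1]
                grad = p_grad[b, s, t]
                px_grad[b, s - 1, t] += term1_deriv * grad
                p_grad[b, s - 1, t] += term1_deriv * grad
                py_grad[b, s, t - 1] += (1.0 - term1_deriv) * grad
                p_grad[b, s, t - 1] += (1.0 - term1_deriv) * grad
    return px_grad, py_grad
```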
torch_mutual_information/mutual_information_cuda.cpp  (view file @ 77eed83f)

@@ -3,18 +3,27 @@
 // forward of mutual_information.  """... """ comment of `mutual_information`
 // in mutual_information.py documents the behavior of this function.
-torch::Tensor mutual_information_cuda(torch::Tensor input,
-                                      torch::Tensor params);
+// It is the core recursion in the sequence-to-sequence mutual information
+// computation.
+// returns 'ans', of dimension B (batch size).
+torch::Tensor mutual_information_cuda(torch::Tensor px,   // [B][S][T+1]
+                                      torch::Tensor py,   // [B][S+1][T]
+                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
+                                      torch::Tensor p);   // [B][S+1][T+1]; an output

-// backward of mutual_information; returns (grad_input, grad_params).
-std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
-                                                            torch::Tensor params,
-                                                            torch::Tensor grad_output);
+// backward of mutual_information; returns (grad_px, grad_py)
+std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor px,
+                                                            torch::Tensor py,
+                                                            std::optional<torch::Tensor> boundary_info,
+                                                            torch::Tensor p,
+                                                            torch::Tensor ans_grad);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("mutual_information_cuda", &mutual_information_cuda,
-        "Learned nonlinearity forward function (CUDA)");
-  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
-        "Learned nonlinearity backward function (CUDA)");
+  m.def("mutual_information_cuda", &mutual_information_cuda,
+        "Mutual information forward function (CUDA)");
+  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
+        "Mutual information backward function (CUDA)");
 }
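Assuming the compiled extension loads under some module name (torch_mutual_information_cuda below is a guess, not something stated in the commit), the bindings declared above would be driven roughly like this:

```python
import torch
import torch_mutual_information_cuda as ext   # assumed module name of the built extension

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1, device="cuda")
py = torch.randn(B, S + 1, T, device="cuda")
p = torch.empty(B, S + 1, T + 1, device="cuda")   # filled by the forward call

ans = ext.mutual_information_cuda(px, py, None, p)          # [B]
px_grad, py_grad = ext.mutual_information_backward_cuda(
    px, py, None, p, torch.ones(B, device="cuda"))          # (grad_px, grad_py)
```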
torch_mutual_information/mutual_information_cuda_kernel.cu  (view file @ 77eed83f)

 #include <torch/extension.h>
 #include <c10/cuda/CUDAStream.h>  // for getCurrentCUDAStream()
 #include <cooperative_groups.h>
+#include <cmath>  // for INFINITY
-/*
-  Tiled summing reduction within a warp.  Requires that the thread-block
-  be 1-dimensional, i.e. blockDim.y == blockDim.z == 1.  Does not use
-  __syncthreads, so it is safe to call in a subset of threads.
-  TODO: we can in principle do this without a buffer, using __shfl_down()
-  (see here https://sodocumentation.net/cuda/topic/6566/parallel-reduction--e-g--how-to-sum-an-array-)
-  if CC >= 3.0.
-  Args:
-      threads_per_tile:  Must be a power of 2 in the interval [1,32].  Summation is
-                         within blocks of threads of this size.
-      buf:               Pointer to the start of a __shared__ buffer of size
-                         blockDim.x, to be used as a temporary within this function.
-      val:               The value to be summed
-  Return:
-      Threads where threadIdx.x % threads_per_tile == 0 will return the sum:
-        \sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i]
-      The return value in other threads is undefined.
-*/
-template <typename scalar_t>
-__forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
-                                                           __volatile__ scalar_t *buf,
-                                                           scalar_t val) {
-  // Each iteration halves the number of active threads.
-  // Each thread adds its partial sum[i] to sum[lane+i].
-  for (int i = threads_per_tile / 2; i > 0; i /= 2) {
-    buf[threadIdx.x] = val;
-    if (threadIdx.x % threads_per_tile < i)
-      val += buf[threadIdx.x + i];
-  }
-  return val;  // Only threads with threadIdx.x % threads_per_tile == 0 will
-               // return the full sums of their tiles.
-}
+// returns log(exp(x) + exp(y)).
+__forceinline__ __device__ double LogAdd(double x, double y) {
+  double diff;
+  if (x < y) {
+    diff = x - y;
+    x = y;
+  } else {
+    diff = y - x;
+  }
+  // diff is negative.  x is now the larger one.
+  if (diff - diff != 0)
+    return x;   // x and y are probably -inf.  Return the larger one.
+  else
+    return x + log1p(exp(diff));
+}
+
+// returns log(exp(x) + exp(y)).
+__forceinline__ __device__ inline float LogAdd(float x, float y) {
+  float diff;
+  if (x < y) {
+    diff = x - y;
+    x = y;
+  } else {
+    diff = y - x;
+  }
+  // diff is negative.  x is now the larger one.
+  if (diff - diff != 0)
+    return x;   // x and y are probably -inf.  Return the larger one.
+  else
+    return x + log1p(exp(diff));
+}
 /*
   Forward of mutual_information.  Each thread block handles blocks of (x, y) shape
-  equal to (BLOCK_S_SIZE, BLOCK_T_SIZE), e.g. (4, 64).  Thread blocks loop over such
-  blocks, but they might typically loop only once.  We sequentially launch groups of
-  threads in such a way that thread-blocks within a group do not depend on each other.
+  equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).  Thread blocks loop over such
+  blocks, but they might loop only once if there is not that much data to process.
+  We sequentially launch groups of threads in such a way that thread-blocks
+  within a group do not depend on each other.

   Template args:
       scalar_t: the floating-point type, e.g. float, double, maybe half.

   Args:
-      input:  input image, shape (B, C, T) where B is batch size, C is
+      px:  log-odds ratio of generating next x in the sequence, i.e.
+           xy[b][s][t] is the log-odds probability of generating x_t of
+           the b'th image given subsequences of length (s, t).  (See
+           mutual_information.py for more info).  Shape [B][S][T + 1]
+      py:  log-odds ratio of generating next y in the sequence.
+           Shape [B][S + 1][T]
+      p:   matrix of mutual information of sub-sequences, that this
+           function writes to.  Shape [B][S + 1][T + 1].  This function
+           computes the following recursion:
+             p[b,0,0] = 0.0
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])      (if s > 0 or t > 0)
+      boundary:  If set, a tensor of shape [B][4] of type int64_t, which
+           contains, for each batch element, [s_begin, t_begin, s_end, t_end]
+           which are the beginning and end (one-past-the-last) of the
+           x and y sequences that we should process.  If not set, these
+           default to (0, 0, S, T), and they should not exceed these bounds
+           or be empty (i.e. s_begin <= t_begin or s_end <= t_end).
           the number of channels and T is the time axis.  (For more-than-1d
           convolution setups, T would really be more than 1 axis, reshaped).
       params:  of shape (C, N+1) where N is the number of linear regions in the
@@ -75,68 +103,74 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
   This kernel is allocated with `extern_buf` containing enough memory
   to store 2*N + 3 values of type scalar_t.

-  The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
-  The requirements on the grid dimension are:
-      gridDim.x == num-channels C (required)
-      1 <= gridDim.y <= B, where B is the number of blocks
-      gridDim.z == 1
-  When we invoke this kernel, we'll invoke it as:
-    mutual_information_kernel<<<gridDim, blockDim, bytesShared, stream>>>
-  where bytesShared is the number of bytes needed in `extern_buf`:
-     bytesShared = sizeof(shared_t) * (2N + 3)
-  We also require N + 1 <= THREADS_PER_BLOCK.
+  The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
+  be at least 128.
 */
-extern __shared__ int extern_buf[];

-template <typename scalar_t,
-          int BLOCK_S_SIZE,   // e.g. BLOCK_S_SIZE == 4; power of 2
-          int BLOCK_T_SIZE>   // e.g. BLOCK_T_SIZE == 64; power of 2.
-                              // BLOCK_T_SIZE * 4 must equal num_threads; and must be >= 128,
-                              // so BLOCK_T_SIZE >= 32 is required.
-                              // (Note: this 4 is unrelated to BLOCK_S_SIZE but can be viewed as 1<<2,
-                              // where 2 is the loop unrolling factor).
+template <typename scalar_t,
+          int BLOCK_SIZE>     // e.g. BLOCK_SIZE == 16 or 32.  Note: we require the
+                              // num-threads be at least 128.
 __global__
 void mutual_information_kernel(
-    torch::PackedTensorAccessor32<scalar_t, 3> px,       // B, S, T, i.e. batch, x_seq_length, y_seq_length
-    torch::PackedTensorAccessor32<scalar_t, 3> py,       // B, S, T, as above
-    torch::PackedTensorAccessor32<scalar_t, 3> p,        // B, S, T, as above.  This is an output.
-    torch::PackedTensorAccessor32<scalar_t, 2> boundary, // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
+    torch::PackedTensorAccessor32<scalar_t, 3> px,       // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> py,       // B, S + 1, T.
+    torch::PackedTensorAccessor32<scalar_t, 3> p,        // B, S + 1, T + 1.  This is an output.
+    torch::PackedTensorAccessor32<int64_t, 2> boundary,  // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
     torch::PackedTensorAccessor32<scalar_t, 1> ans,      // [B]
     int iter) {
   // This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
   //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
   // so that each group depends on the previous group...
-  const int block_dimx = BLOCK_T_SIZE * 4;   // known at compile time.
-  assert(blockDim.x == block_dimx);

   const int B = px.size(0),
             S = px.size(1),
             T = py.size(2);
   // num_s_blocks and num_t_blocks are the number of blocks we need to cover the
   // array of size (S, T) with blocks of this size, in the s and t directions
   // respectively.
-  const int num_s_blocks = (S + BLOCK_S_SIZE - 1) / BLOCK_S_SIZE,
-            num_t_blocks = (T + BLOCK_T_SIZE - 1) / BLOCK_T_SIZE;
-  // num_blocks_this_iter is an upper bound on the number of blocks that might
-  // be active on this iteration.  We go from the bottom left of the image
-  // so that on iter == 0 we process only one block with block-index (0, 0)
-  // then on iter == 1 we process block-indexes (1, 0) and (0, 1); and then on iter==2
-  // we process (2, 0), (1, 1) and (0, 2); and so on.  We also will never have more
-  // than `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
-  // the numbering we use corresponds to s and not t, so if we hit the num_t_blocks limit,
-  // the lowest-numbered blocks on s would just not be active and we'll 'continue' below).
+  // You can read the following expressions as simplifications of, for example,
+  //    num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
+  // i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            num_t_blocks = T / BLOCK_SIZE + 1;
+  // num_blocks_this_iter is an upper bound on the number of blocks of size
+  // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration.  We go
+  // from the bottom left of the image so that on iter == 0 we process only one
+  // block with block-index (0, 0) then on iter == 1 we process block-indexes
+  // (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0,
+  // 2); and so on.  We also will never have more than `num_s_blocks` blocks
+  // (We'll never have more than num_t_blocks either, but the numbering we use
+  // corresponds to s and not t, so if we hit the num_t_blocks limit, the
+  // lowest-numbered blocks on s would just not be active and we'll 'continue'
+  // below).
   int num_blocks_this_iter = min(iter + 1, num_s_blocks);

-  __shared__ scalar_t px_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
-                      py_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
-                      p_buf[BLOCK_S_SIZE + 1][BLOCK_T_SIZE + 1];  // 1st row/col of p_buf
-                                                                  // correspond to the previous
-                                                                  // blocks, or an edge case.
+  // For the block with s_block_begin == 0 and t_block_begin == 0 (for
+  // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
+  // for out-of-range indexes.
+  // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
+  __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
+                      py_buf[BLOCK_SIZE][BLOCK_SIZE];
+
+  // 1st row/col of p_buf correspond to the previous blocks, or to an edge case.
+  // So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t
+  // - 1] - normalizer); or 0 for out-of-range values.
+  __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
+
+  // boundary_buf will be used to store the b'th row of `boundary` if we have
+  // boundary information supplied.
+  __shared__ int64_t boundary_buf[4];
+
+  if (threadIdx.x == 0) {
+    boundary_buf[0] = 0;
+    boundary_buf[1] = 0;
+    boundary_buf[2] = S;
+    boundary_buf[3] = T;
+  }

   // batch_block_iter iterates over both batch elements (index b), and block
-  // indexes
+  // indexes in the range [0..num_blocks_this_iter-1]
   for (int batch_block_iter = blockIdx.x;
        batch_block_iter < B * num_blocks_this_iter;
        batch_block_iter += gridDim.x) {
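The wavefront schedule described in the comments above can be sketched in a few lines of Python (purely illustrative, not from the commit): on iteration iter, exactly the blocks with s_block + t_block == iter are active, so every block's bottom and left neighbours were finished on the previous iteration.

```python
def blocks_for_iter(it, num_s_blocks, num_t_blocks):
    # Anti-diagonal wavefront: block (s_block, t_block) is processed on
    # iteration s_block + t_block.
    blocks = []
    for s_block in range(min(it + 1, num_s_blocks)):
        t_block = it - s_block
        if t_block < num_t_blocks:
            blocks.append((s_block, t_block))
    return blocks

# e.g. with 3 x 4 blocks:
#   iter 0 -> [(0, 0)]
#   iter 1 -> [(0, 1), (1, 0)]
#   iter 2 -> [(0, 2), (1, 1), (2, 0)]
```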
@@ -149,146 +183,219 @@ void mutual_information_kernel(
       bool is_origin_block = (s_block_begin * t_block_begin == 0);

       int s_end, t_end;  // s_end and t_end are the end points (last-plus-one) of the entire sequence.
-      if (boundary.size(0) == 0) {
-        s_end = S;
-        t_end = T;
-      } else {
-        int s_begin = boundary_buf[0],
-            t_begin = boundary_buf[1];
-        s_end = boundary_buf[2];
-        t_end = boundary_buf[3];
-        s_block_begin += s_begin;
-        t_block_begin += t_begin;
-      }
+      if (threadDim.x < 4 && boundary.size(0) != 0)
+        boundary_buf[threadDim.x] = boundary[b][threadDim.x];
+      __syncthreads();
+      int s_begin = boundary_buf[0],
+          t_begin = boundary_buf[1];
+      s_end = boundary_buf[2];
+      t_end = boundary_buf[3];
+      s_block_begin += s_begin;
+      t_block_begin += t_begin;

-      // block_S and block_T are the actual sizes of this block, up to
-      // (BLOCK_S_SIZE, BLOCK_T_SIZE) but possibly truncated if we
-      // are towards the end of the sequence.
-      int block_S = min(BLOCK_T_SIZE, s_end - s_block_begin),
-          block_T = min(BLOCK_S_SIZE, t_end - t_block_begin);
+      // block_S and block_T are the actual sizes of this block, no greater than
+      // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
+      // the end of the sequence.
+      int block_S = min(BLOCK_SIZE, s_end - s_block_begin),
+          block_T = min(BLOCK_SIZE, t_end - t_block_begin);

       if (block_S <= 0 || block_T <= 0)
         continue;

-      // Load px_buf and py_buf.  We exponentiate; the assumption is that they
-      // won't overflow or underflow!  If they overflow we'll detect it later!
-      for (int i = threadDim.x; i < BLOCK_S_SIZE * BLOCK_T_SIZE; i += block_dimx) {
-        int t = i % BLOCK_T_SIZE,
-            s = i / BLOCK_T_SIZE;
-        if (s < block_S && t < block_T) {
-          px_buf[s][t] = exp(px[b][s + s_block_begin][t + t_block_begin]);
-          py_buf[s][t] = exp(py[b][s + s_block_begin][t + t_block_begin]);
-        } else {
-          // Not necessary?  We'll see
-          px_buf[s][t] = 0.0;
-          py_buf[s][t] = 0.0;
-        }
-      }
+      // Load px_buf and py_buf.  We exponentiate; the assumption is that they most likely
+      // won't overflow or underflow, but if they do overflow we'll detect it later; we'll
+      // also detect certain kinds of underflow.
+      for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+        int t_in_block = i % BLOCK_SIZE,
+            s_in_block = i / BLOCK_SIZE,
+            s = s_in_block + s_block_begin,
+            t = t_in_block + t_block_begin;
+        // the comparisons with S and T below just make sure we don't access
+        // out-of-memory regions; they do not guarantee we are in the range given
+        // by s_begin, s_end and so on.  Note: comparing as unsigned int makes sure
+        // the index is nonnegative.
+        scalar_t this_px = 0.0;
+        if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) && t <= T)
+          this_px = exp(px[b][s - 1][t]);
+        px_buf[s_in_block][t_in_block] = this_px;
+        scalar_t this_py = 0.0;
+        if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) && s <= S)
+          this_py = exp(py[b][s][t - 1]);
+        py_buf[s_in_block][t_in_block] = this_py;
+      }

       // Load the 1st row and column of p_buf (except element[0][0] is not needed).
-      if (threadIdx.x < 64) {  // 64 == warp size...
-        if (threadIdx.x <= BLOCK_S_SIZE) {
-          int s = threadIdx.x - 1,
-              t = -1;
-          if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
-              static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
-            p_buf[threadIdx.x][0] = p[s + s_block_begin][s + t_block_begin];
-          else
-            p_buf[threadIdx.x][0] = -infinity;
-        }
-      } else {
-        if (threadIdx.x - 64 <= BLOCK_T_SIZE) {
-          int i = threadIdx.x - 64,
-              t = i - 1,
-              s = -1;
-          if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
-              static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
-            p_buf[0][i] = p[s + s_block_begin][s + t_block_begin];
-          else
-            p_buf[0][i] = (is_origin_block && i == 1 ? 1.0 : -infinity);
-        }
-      }
+      // Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer).
+      if (threadIdx.x < 64) {  // 64 == warp size.  First half of threads...
+        if (threadIdx.x <= BLOCK_SIZE) {
+          // s_in_p_buf and t_in_p_buf are simply the indexes into p_buf;
+          // this s and t are offsets relative to the block start.
+          int s_in_p_buf = threadIdx.x,
+              t_in_p_buf = 0,
+              s = s_in_p_buf + s_block_begin - 1,
+              t = t_in_p_buf + t_block_begin - 1;
+          // The if-statement below just guards against out-of-range memory
+          // accesses, it does not guarantee that we really need these values.
+          scalar_t this_p = -INFINITY;
+          if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
+              static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
+            this_p = p[s + s_block_begin][s + t_block_begin];
+          p_buf[threadIdx.x][0] = this_p;
+        }
+      } else {  // Another warp handles the other leg
+        if (threadIdx.x - 64 <= BLOCK_SIZE) {
+          int s_in_p_buf = 0,
+              t_in_p_buf = threadIdx.x - 64,
+              s = s_in_p_buf + s_block_begin - 1,
+              t = t_in_p_buf + t_block_begin - 1;
+          // The if-statement below just guards against out-of-range memory
+          // accesses, it does not guarantee that we really need these values.
+          scalar_t this_p = -INFINITY;
+          if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
+              static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
+            this_p = p[s + s_block_begin][s + t_block_begin];
+          p_buf[threadIdx.x][0] = this_p;
+        }
+      }
+      __syncthreads();
+
+      // We read p_buf in log-space; subtract 'normalizer', which mathematically
+      // could be any finite number, to get in a reasonable range of probabilities,
+      // and then exponentiate.  We'll do everything in non-log space, and later
+      // take a log before we write out the data.
+      scalar_t normalizer = (is_origin_block ? 0.0 :
+                             max(px_buf[0][1], px_buf[1][0]));
+      // Normalize and exponentiate the edge elements of p_buf, i.e. the elements
+      // where at least one index is 0.  The [0][0] element is special; we write 0.0,
+      // and we'll overwrite with 1.0 if there is a panic situation due to
+      // overflow.
+      if (threadIdx.x <= BLOCK_SIZE) {
+        // p_buf[0][0] is never used for its normal purpose; we set it to zero.
+        // We'll later write a nonzero value there if something goes wrong, as a
+        // 'panic' indicator.
+        p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
+                                 exp(p_buf[threadIdx.x][0] - normalizer));
+      } else if ((int)threadIdx.x - 64 < BLOCK_SIZE) {
+        p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
+      }
+
+      if (threadIdx.x == 0) {
+        // This if-statement is an optimization and modification of the loop below
+        // for the value i == 0, i.e. inner-iteration == 0.  The modification
+        // is to use 0.0 if this is the "origin block" with s_block_begin == 0 and
+        // t_block_begin == 0.  This corresponds to the probability of the pair of
+        // sequences of length (0, 0).
+        p_buf[1][1] = (is_origin_block ? 0.0 :
+                       p_buf[0][1] * px_buf[0][0] +
+                       p_buf[1][0] * py_buf[0][0]);
+      }

-      N = params.size(1) - 1,
-      K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
-      const int c = blockIdx.x;  // c is channel index
-      scalar_t *y_vals = (scalar_t*) extern_buf,        // [N], actually there are 3
-                                                        // spaces between here and
-                                                        // `params_buf` for storing scale
-                                                        // and inv_scale and l == params[c][0].
-               *params_buf = (scalar_t*) y_vals + 3 + N;  // [N].  params_buf[n] contains params[c][n-1].
-                                                        // params_buf[-1] contains params[c][0] == log of scale;
-                                                        // params_buf[-2] contains scale, params_buf[-3]
-                                                        // contains inv_scale.
-      // Load parameters
-      if (threadIdx.x <= N)
-        params_buf[threadIdx.x - 1] = params[c][threadIdx.x];
-      __syncthreads();
-      if (threadIdx.x == 0) {
-        scalar_t scale = exp(params_buf[-1]);
-        params_buf[-2] = scale;
-        params_buf[-3] = 1.0 / scale;
-      }
-      __syncthreads();
-      if (threadIdx.x == 0) {
-        scalar_t scale = params_buf[-2],
-                 sum_positive = 0.0;
-        for (int i = 0; i < K; i++) {
-          // params_buf is indexed with an index one less than params.
-          scalar_t pos_scaled_param = params_buf[K + i] * scale;
-          y_vals[K + i] = sum_positive - pos_scaled_param * i;
-          sum_positive += pos_scaled_param;
-        }
-      } else if (threadIdx.x == 64) {
-        scalar_t scale = params_buf[-2],
-                 sum_negative = 0.0;
-        for (int i = 0; i < K; i++) {
-          scalar_t neg_scaled_param = params_buf[K - 1 - i] * scale;
-          sum_negative -= neg_scaled_param;
-          y_vals[K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-        }
-      }
-      __syncthreads();
-      scalar_t inv_scale = params_buf[-3];
-      int T_inc = THREADS_PER_BLOCK / images_per_thread_block,
-          b_offset = threadIdx.x / T_inc,    // offset within batch
-          t_start = threadIdx.x % T_inc;
-      for (int b = blockIdx.y * images_per_thread_block + b_offset; b < B;
-           b += gridDim.y * images_per_thread_block) {
-        // We do "t += THREADS_PER_BLOCK" instead of t += (THREADS_PER_BLOCK /
-        // images_per_thread_block) as a small optimization because the only case we
-        // really need to loop is when images_per_thread_block == 1: we only let
-        // images_per_thread_block > 1 if T * images_per_thread_block <=
-        // THREADS_PER_BLOCK.
-        for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
-          scalar_t this_input = input[b][c][t],
-                   x = this_input * inv_scale + K;
-          if (x < 0) x = 0;
-          else if (x >= N) x = N - 1;
-          // C++ rounds toward zero.
-          int n = (int) x;
-          // OK, at this point, 0 <= n < N.  Versus the CPU code, we removed the
-          // factor of 'scale' because params_buf already has that factor.
-          output[b][c][t] = this_input * params_buf[n] + y_vals[n];
-        }
-      }

+      scalar_t p_buf_s1_t;  // This is for an optimization.
+      if (i < BLOCK_SIZE) {
+        int s = threadIdx.x;
+        p_buf_s1_t = p_buf[s + 1][0];
+      }
+
+      for (int i = 1; i < 2 * BLOCK_SIZE; i++) {
+        // i is the inner iteration, which corresponds to the (s + t) indexes of the
+        // elements within the block that we write.  So i == 0 writes positions
+        // (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
+        // (0, 2), (1, 1) and (2, 0); and so on.
+        // Note: not many threads participate in this part, only up to BLOCK_SIZE
+        // at most.  Unfortunately we couldn't figure out a very meaningful way
+        // for more threads to do work, that looked like it would really speed
+        // things up.
+        // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
+        // but we do at least do the I/O in an efficient way and keep the
+        // inner loop simple and fast (e.g. no exp() or log()).
+        int s = threadIdx.x,
+            t = i - s;
+        if (t >= 0) {
+          // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
+          // row and column for context from previous blocks.  Taking into account
+          // the way these buffers relate to the tensors p, px and py, and
+          // ignoring `normalizer`, the code below can be interpreted as follows,
+          // writing sbb for s_block_begin and tbb for t_block_begin:
+          //
+          //   p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
+          //                               p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1])
+          //
+          // where you can see that apart from the offsets of tbb and sbb, this is
+          // the same as the recursion defined for p in
+          // mutual_information.py:mutual_information_recursion().
+#if 1
+          p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
+                                p_buf[s + 1][t] * py_buf[s][t];
+#else
+          // This is an optimization of the statement above (the other half of
+          // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
+          // the need for a load from shared memory.
+          p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
+          p_buf[s + 1][t + 1] = p_buf_s1_t;
+#endif
+        }
+        __syncthreads();
+      }
+
+      // Write out the data.
+      for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+        int t = i % BLOCK_SIZE,
+            s = i / BLOCK_SIZE;
+        if (s < block_S && t < block_T) {
+          float this_p = p_buf[s + 1][t + 1];
+          p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p);
+          if (this_p - this_p != 0 || this_p == 0)
+            p_buf[0][0] = 1.0;    // This is a "panic" flag.
+        }
+      }
+
+      if (threadIdx.x == 0) {
+        // Write `ans`, if this is the final (top-right) block in its sequence.
+        // Logically, the following equation corresponds to:
+        //    ans[b] = p[b][s_end][t_end]
+        if (s_block_begin + S > s_end && t_block_begin + T > t_end)
+          ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]);
+      }
+
+      if (p_buf[0][0] != 0.0) {
+        // "panic" flag set.  We need to re-do the computation using log-add.
+        // This time we won't use the buffers, we'll just load and save from main
+        // memory.  This code should very rarely be reached; and anyway, caching
+        // should help us quite a bit.
+        for (int i = 0; i < 2 * BLOCK_SIZE; i++) {
+          int block_s = threadIdx.x,
+              block_t = i - block_s;
+          if (static_cast<unsigned int>(block_t) < static_cast<unsigned int>(block_T) &&
+              block_s < block_S) {
+            int s = block_s + s_block_begin,
+                t = block_t + t_block_begin;
+            float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
+                  p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
+                  this_px = px[b][s][t],
+                  this_py = py[b][s][t];
+            float this_p = LogAdd(p_s1 + this_px,
+                                  p_t1 + this_py);
+            if (i == 0 && is_origin_block)
+              this_p = 0.0;
+            p[b][s][t] = this_p;
+          }
+        }
+        if (threadIdx.x == 0) {
+          // Write `ans`, if this is the final (top-right) block in its sequence.
+          // This is only reached in the 'panic situation' where we had overflow.
+          if (s_block_begin + S > s_end && t_block_begin + T > t_end)
+            ans[b] = p[b][s_end][t_end];
+        }
+      }
     }
 }
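The forward kernel's normalizer trick, described in the comments above, amounts to the following (a schematic Python sketch, not from the commit; it ignores the -infinity edge cases and the panic/overflow path): subtract a per-block normalizer in log space, exponentiate, run the recursion with plain multiplies and adds, then take the log again (adding the normalizer back) when writing out.

```python
import torch

def blockwise_update(p_edge_top, p_edge_left, px_block, py_block, is_origin_block):
    # p_edge_top:  [T_blk] log-values of p just above the block
    # p_edge_left: [S_blk] log-values of p just left of the block
    # px_block, py_block: [S_blk][T_blk] log-space slices of px and py for this block
    normalizer = 0.0 if is_origin_block else max(p_edge_top.max().item(),
                                                 p_edge_left.max().item())
    S_blk, T_blk = px_block.shape
    ep = torch.zeros(S_blk + 1, T_blk + 1, dtype=px_block.dtype)
    ep[0, 1:] = (p_edge_top - normalizer).exp()    # exp-domain context row
    ep[1:, 0] = (p_edge_left - normalizer).exp()   # exp-domain context column
    epx, epy = px_block.exp(), py_block.exp()
    if is_origin_block:
        ep[1, 1] = 1.0                             # exp(p[0,0]) == exp(0.0)
    start = 1 if is_origin_block else 0
    for i in range(start, 2 * max(S_blk, T_blk) + 1):   # anti-diagonal order
        for s in range(S_blk):
            t = i - s
            if 0 <= t < T_blk:
                ep[s + 1, t + 1] = ep[s, t + 1] * epx[s, t] + ep[s + 1, t] * epy[s, t]
    return normalizer + ep[1:, 1:].log()           # back to log-space
```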
@@ -334,74 +441,135 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
 }

 /*
-  Backward of mutual_information.  Each thread group handles a single channel (channel
-  c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
-  image within the batch).
-
-  Template args:
-      scalar_t: the floating-point type, e.g. float, double, maybe half.
-
-  Args:
-      input:  input image, shape (B, C, T) where B is batch size, C is
-          the number of channels and T is the time axis.  (For more-than-1d
-          convolution setups, T would really be more than 1 axis, reshaped).
-      params:  of shape (C, N+1) where N is the number of linear regions in the
-          piecewise linear function; params[c][0] is l which is
-          a log scale parameter that dictates how far apart
-          the discontinuities in the piecewise linear function are,
-          and params[c][n+1] for 0 <= n < N are the derivatives
-          of the linear parts of the piecewise linear function.
-          The discontinuities of the function are at:
-               exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
-      output:  The transformed input, shape (B , C, T)
-      images_per_thread_block:  The number of images processed by each thread
-          block.  The calling code must guarantee that this is a power
-          of 2, and that EITHER:
-             (THREADS_PER_BLOCK / images_per_thread_block >= T AND
-              THREADS_PER_BLOCK / images_per_thread_block >= N),
-          OR
-             images_per_thread_block == 1
-          .. this is used for a small optimization.
-          ALSO,
-  This kernel is allocated with `extern_buf` containing enough memory
-  to store 2*N + 3 values of type scalar_t.
-  The blockDim must equal (THREADS_PER_BLOCK, 1, 1)
-  The requirements on the grid dimension are:
-      gridDim.x == num-channels C (required)
-      1 <= gridDim.y <= B, where B is the number of blocks
-      gridDim.z == 1
-  When we invoke this kernel, we'll invoke it as:
-    mutual_information_backward_kernel<<<gridDim, blockDim, bytesShared, stream>>>
-  where bytesShared is the number of bytes needed in `extern_buf`:
-     bytesShared = sizeof(shared_t) * (2N + 3)
-  We also require that N <= THREADS_PER_BLOCK (for best performance,
-  N should be quite small, like no larger than 8 or so).
-  We also require 4 <= N <= 16 for this code!
-  And we require that
-    N <= (THREADS_PER_BLOCK / images_per_thread_block)
-  (both sides will be powers of 2).. this ensures that blocks of threads
-  summing the N values are always within the same image, which helps
-  avoid a problem where some loops over 'b' would be done earlier
-  than others, and we'd end up counting certain pixels twice as their
-  output_grad would stay nonzero.
+  Backward of mutual_information.
+
+  If we were to write the forward pass in non-log space, it would be (ignoring
+  edge cases), as follows... we'll prefix all the variable names with e, e.g. ep,
+  to clarify that it's the exp of the actual argument p:
+
+       ep[b][s][t] = ep[b][s - 1][t] * epx[b][s - 1][t] +
+                     ep[b][s][t - 1] * epy[b][s][t - 1].        (eq. 1)
+
+  (A) First we consider the part that involves recursion, i.e. the part involving only gradients of
+  ep.  The backprop involving ep only would be:
+        ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
+        ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].
+  .. and if we add 1 to the s index of the first equation above and 1 to the
+  t index of the second equation, we can see that:
+        ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
+                           ep_grad[b][s][t + 1] * epy[b][s][t].
+  Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
+  == dy/dp / ep == p_grad / ep.
+  I.e. ep_grad = p_grad / ep.
+  So we can write the above as:
+    p_grad[b][s][t] / ep[b][s][t] = p_grad[b][s + 1][t] / ep[b][s + 1][t] * epx[b][s][t] +
+                                    p_grad[b][s][t + 1] / ep[b][s][t + 1] * epy[b][s][t].
+  Or, rearranging:
+    p_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t]) +
+                      p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1]).   (eq. 2)
+
+  (B) The following is the backprop for epx and epy from (eq. 1):
+        epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
+        epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]
+  .. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd:
+        epx_grad[b][s][t] = ep_grad[b][s + 1][t] * ep[b][s][t]
+        epy_grad[b][s][t] = ep_grad[b][s][t + 1] * ep[b][s][t]
+  Using, similar to the above, ep_grad = p_grad / ep, and similarly,
+  epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for ep and so on,
+  the above becomes
+     px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
+     py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
+  Rearranging:
+     px_grad[b][s][t] = p_grad[b][s + 1][t] * exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])   (eq. 3a)
+     py_grad[b][s][t] = p_grad[b][s][t + 1] * exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])   (eq. 3b)
+
+  Defining terms that are common to (eq. 2) and (eqs. 3a,3b), write:
+     xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])       (eq. 4)
+     yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])       (eq. 5)
+  .. and note that these quantities are <= 1 so there is no problem doing
+  the exponentiation.  So the recursion can be simplified as:
+     p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
+                       p_grad[b][s][t + 1] * yderiv[b][s][t]                 (eq. 6)
+     px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]                (eq. 7)
+     py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]                (eq. 8)
+  (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
+  not clear to me that this is the best strategy since that would require an extra
+  write to shared memory within the loop that's the limiting factor.)
+
+  The backward pass will be slightly different from the forward pass in terms of
+  how we store p (and p_grad), because for writing a particular block of p_grad, we
+  need context on the top and right instead of the bottom and left.
 */
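Equations (4)-(8) can be checked with a short Python sketch (illustrative only, not from the commit; it ignores -infinity edge cases and assumes default boundaries and a single batch element):

```python
import torch

def xy_deriv_backward(px, py, p, ans_grad):
    # px: [S][T+1], py: [S+1][T], p: [S+1][T+1] from the forward pass, ans_grad: scalar.
    S, T1 = px.shape
    T = T1 - 1
    xderiv = (p[:S, :] + px - p[1:, :]).exp()        # (eq. 4), shape [S][T+1]
    yderiv = (p[:, :T] + py - p[:, 1:]).exp()        # (eq. 5), shape [S+1][T]
    p_grad = torch.zeros(S + 1, T + 1, dtype=p.dtype)
    p_grad[S, T] = ans_grad
    for s in range(S, -1, -1):                       # top-right to bottom-left
        for t in range(T, -1, -1):
            if s < S:
                p_grad[s, t] += p_grad[s + 1, t] * xderiv[s, t]   # (eq. 6), x term
            if t < T:
                p_grad[s, t] += p_grad[s, t + 1] * yderiv[s, t]   # (eq. 6), y term
    px_grad = p_grad[1:, :] * xderiv                 # (eq. 7)
    py_grad = p_grad[:, 1:] * yderiv                 # (eq. 8)
    return px_grad, py_grad
```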
 template <typename scalar_t>
 __global__
 void mutual_information_backward_kernel(
-    torch::PackedTensorAccessor32<scalar_t, 3> input,       // B, C, T, i.e. batch, channels, time
-    torch::PackedTensorAccessor32<scalar_t, 2> params,      // C, N + 1
-    torch::PackedTensorAccessor32<scalar_t, 3> output_grad, // B, C, T
-    torch::PackedTensorAccessor32<scalar_t, 3> input_grad,  // B, C, T
-    // params_grad is of dim (gridDim.y, C, N + 1), we'll sum over dim 0.
-    torch::PackedTensorAccessor32<scalar_t, 3> params_grad,
-    int images_per_thread_block) {
+    torch::PackedTensorAccessor32<scalar_t, 3> px,       // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> py,       // B, S + 1, T.
+    torch::PackedTensorAccessor32<scalar_t, 3> p,        // B, S + 1, T + 1.  Produced in forward pass.
+    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B].  This is an input.
+    torch::PackedTensorAccessor32<scalar_t, 3> p_grad,   // B, S + 1, T + 1.  This is a temporary.
+    torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
+    torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
+    torch::PackedTensorAccessor32<int64_t, 2> boundary,  // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
+    int iter) {
+  // This kernel is sequentially called with 'iter' = num_iters - 1, num_iters - 2, .. 0,
+  // where num_iters can be taken to be any sufficiently large number but will actually be:
+  //   (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
+  // For statements that are the same as the forward pass, we are omitting some comments
+  // that we made there.  We'll focus, in the comments, on differences from the forward pass.
+  const int B = px.size(0),
+            S = px.size(1),
+            T = py.size(2);
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            num_t_blocks = T / BLOCK_SIZE + 1,
+            num_blocks_this_iter = min(iter + 1, num_s_blocks);
+
+  // px_buf and py_buf are used temporarily to store the px and py values,
+  // but then modified to store the "xderiv" and "yderiv" values defined
+  // in (eq. 4) and (eq. 5) above.  For out-of-range values, we'll write 0.0
+  // here.
+  __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
+                      py_buf[BLOCK_SIZE][BLOCK_SIZE];
+
+  // p_buf is initially used to store p, and then (after we are done putting
+  // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
+  // p_grad.
+  //
+  // Unlike in the forward pass, p_buf has the same numbering as px_buf and
+  // py_buf, not offset by 1: e.g., for the origin block, p_buf[0][0] refers
+  // to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
+  // the block for px_buf and py_buf; unlike in the forward pass, we store
+  // context on the top right, not the bottom left, i.e. the elements at
+  // (one past the largest indexes in the block).
+  //
+  // For out-of-range elements of p_buf, we'll put zero.
+  __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
+
+  // boundary_buf will be used to store the b'th row of `boundary` if we have
+  // boundary information supplied.
+  __shared__ int64_t boundary_buf[4];
+  boundary_buf[0] = 0;
+  boundary_buf[1] = 0;
+  boundary_buf[2] = S;
+  boundary_buf[3] = T;

-  const int B = input.size(0),
-            C = input.size(1),
@@ -669,38 +837,52 @@ void mutual_information_backward_kernel(
+// forward of mutual_information.  See """... """ comment of `mutual_information` in
+// mutual_information.py for documentation of the behavior of this function.
+torch::Tensor mutual_information_cuda(torch::Tensor px,
+                                      torch::Tensor py,
+                                      std::optional<torch::Tensor> optional_boundary,
+                                      torch::Tensor p) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && p.device().is_cuda(),
+              "inputs must be CUDA tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0),
+            S = px.size(1),
+            T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  torch::Tensor ans = torch::empty({B}, opts);
+
+  int num_threads = 128,
+      num_blocks = 128;
+  const int num_s_blocks = S / BLOCK_SIZE + 1,
+            num_t_blocks = T / BLOCK_SIZE + 1,
+            num_iters = std::max<int>(num_s_blocks, num_t_blocks);
+
+  bool has_boundary = (bool)optional_boundary;
+  if (!has_boundary)
+    optional_boundary = torch::empty({0, 0}, long_opts);
+
+  for (int iter = 0; iter < num_iters; iter++) {
+    mutual_information_kernel<scalar_t, 32><<<num_blocks, num_threads>>>(
+        px.packed_accessor32<scalar_t, 3>(),
+        py.packed_accessor32<scalar_t, 3>(),
+        p.packed_accessor32<scalar_t, 3>(),
+        optional_boundary.value().packed_accessor32<int64_t, 2>(),
+        ans.packed_accessor32<scalar_t, 1>(),
+        iter);
+  }

-torch::Tensor mutual_information_cuda(torch::Tensor input,
-                                      torch::Tensor params) {
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
-  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
-
-  const int B = input.size(0),
-            C = input.size(1),
-            T = input.size(2),
-            N = params.size(1) - 1;
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-
-  torch::Tensor output = torch::empty({B, C, T}, opts);
-  if (C * B * T == 0)
-    return output;
-
-  int images_per_thread_block = 1;
-  while (images_per_thread_block * 2 * T <= THREADS_PER_BLOCK)
-    images_per_thread_block *= 2;

   int grid_dim_y = 1;
   // If the number of channels is quite small (<128) we can launch more thread