OpenDAS / FAST-RNNT · Commits

Commit 3c1ec347, authored Jul 29, 2021 by Daniel Povey
parent 8ed6deff

    Get it to a stage where it looks like it might compile

Showing 3 changed files with 277 additions and 331 deletions:
  torch_mutual_information/mutual_information.py               +33   -27
  torch_mutual_information/mutual_information_cpu.cpp          +16   -113
  torch_mutual_information/mutual_information_cuda_kernel.cu   +228  -191
torch_mutual_information/mutual_information.py  (view file @ 3c1ec347)
...
@@ -44,66 +44,72 @@ except ImportError:
 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                           boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+                                           boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
     if input.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, q)
+            px, py, boundaries, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, q)
+            px, py, boundaries, p)
 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                            boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                                            boundaries: torch.Tensor, p: torch.Tensor,
+                                            ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
-        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, q))
+        overwrite_ans_grad = True
+        if overwrite_ans_grad:
+            ans_grad_copy = ans_grad.clone()
+        ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
+            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+        if overwrite_ans_grad:
+            if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
+                print(f"Warning: possible excsssive roundoff in mutual information backward "
+                      "recursion: {ans_grad} vs. {ans_grad_copy}");
+        return ans
     else:
-        return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, q))
+        return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
+            px, py, boundaries, p, ans_grad))
 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
-        (B, S, T) = px.shape
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
+                boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+        (B, S, T1) = px.shape
+        T = T1 - 1;
+        assert py.shape == (B, S + 1, T)
+        if boundaries is not None:
+            assert boundaries.shape == (B, 4)

-        # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is related to
+        # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is the
         # the mutual information of the pair of subsequences of x and y that are of
         # length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
         # the mutual information of the entire pair of sequences, i.e. of lengths
         # S and T respectively.
-        # q is a rearrangement of a tensor p which is of shape (B,S,T),
-        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
-        # representation is that each row of q depends only on the previous row,
-        # so we can access the rows sequenctially and this leads to
-        # better memory access patterns.  We are assuming that most likely
-        # T < S, which means that q should not require much more memory than p.
-        #
-        # Actually we access q beginning from 0 indexes even if `boundaries`
-        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
-        # q[b, s-s_begin + t-t_begin, t-t_begin];
-        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
+        # It is computed as follows (in C++ and CUDA):
+        #       p[b,0,0] = 0.0
+        #       p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+        #                          p[b,s,t-1] + py[b,s,t-1])
+        #               if s > 0 or t > 0,
+        #               treating values with any -1 index as -infinity.
+        #      .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
         p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

         ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)

         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, q)
+            ctx.save_for_backward(px, py, boundaries, p)
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, q) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
+        (px, py, boundaries, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
         return (px_grad, py_grad, None)
...
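For orientation, here is a minimal, deliberately slow Python sketch of the recursion described in the comments above, the one the C++/CUDA kernels implement, assuming no `boundaries` and a single batch element. The helper name `mutual_information_reference` is made up for illustration and is not part of the package:

    # Illustrative sketch only (hypothetical helper, not part of torch_mutual_information).
    import torch

    def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
        # px: (S, T + 1), py: (S + 1, T); returns p of shape (S + 1, T + 1) with
        # p[0][0] = 0 and
        # p[s][t] = logaddexp(p[s-1][t] + px[s-1][t], p[s][t-1] + py[s][t-1]),
        # treating any -1 index as -infinity.
        S = px.shape[0]
        T = py.shape[1]
        neg_inf = torch.tensor(float('-inf'), dtype=px.dtype)
        p = torch.full((S + 1, T + 1), float('-inf'), dtype=px.dtype)
        p[0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                term_x = p[s - 1, t] + px[s - 1, t] if s > 0 else neg_inf
                term_y = p[s, t - 1] + py[s, t - 1] if t > 0 else neg_inf
                p[s, t] = torch.logaddexp(term_x, term_y)
        return p

The kernels below compute the same quantity blockwise, in exp-space with a per-block normalizer, rather than element by element as in this sketch.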
torch_mutual_information/mutual_information_cpu.cpp  (view file @ 3c1ec347)
...
@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

-  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+  bool has_boundary = (bool)optional_boundary;
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
+      px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S, T + 1}, opts)),
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
+                 torch::empty({B, S + 1, T}, opts));

   auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
-  bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);
...
@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       auto px_a = px.packed_accessor32<scalar_t, 3>(),
           py_a = py.packed_accessor32<scalar_t, 3>(),
           p_a = p.packed_accessor32<scalar_t, 3>(),
-          p_grad_a = p.packed_accessor32<scalar_t, 3>();
+          p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
+          px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
+          py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
       auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
...
@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
             // .. which obtains p_a[b][s][t - 1] from a register.
             scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
                 // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1],  <-- not
                 // actually needed..
                 total = p_a[b][s][t],
                 term1_deriv = exp(term1 - total),
                 term2_deriv = 1.0 - term1_deriv,
                 grad = p_grad_a[b][s][t],
                 term1_grad = term1_deriv * grad,
                 term2_grad = term2_deriv * grad;
+            // We can assign to px_grad_a here rather than add, because we
+            // know it's currently zero.
-            TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
             px_grad_a[b][s - 1][t] = term1_grad;
-            TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);  // likewise..
             p_grad_a[b][s - 1][t] = term1_grad;
-            py_grad_a[b][s][t - 1] = term2_grad;
+            py_grad_a[b][s][t - 1] += term2_grad;
             p_grad_a[b][s][t - 1] += term2_grad;
           }
         }
...
@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
         }
       }
     }));
-  return ans;
-}
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(input.sizes() == output_grad.sizes(),
-              "Output-grad vs. input sizes mismatch.");
-  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
-  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1,
-      K = N / 2;
-
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-
-  torch::Tensor y_vals = torch::empty({C, N}, opts),
-      y_vals_grad = torch::zeros({C, N}, opts),
-      params_grad = torch::zeros({C, N + 1}, opts),
-      input_grad = torch::zeros({B, C, T}, opts);
-
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
-      auto params_a = params.accessor<scalar_t, 2>(),
-          params_grad_a = params_grad.accessor<scalar_t, 2>(),
-          y_vals_a = y_vals.accessor<scalar_t, 2>(),
-          y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
-      for (int c = 0; c < C; c++) {
-        scalar_t sum_negative = 0.0,
-            sum_positive = 0.0,
-            scale = exp(params_a[c][0]);
-        for (int i = 0; i < K; i++) {
-          scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-              neg_scaled_param = params_a[c][K - i] * scale;
-          y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          sum_positive += pos_scaled_param;
-          sum_negative -= neg_scaled_param;
-          y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-        }
-      }
-      auto input_a = input.accessor<scalar_t, 3>(),
-          output_grad_a = output_grad.accessor<scalar_t, 3>(),
-          input_grad_a = input_grad.accessor<scalar_t, 3>();
-      for (int b = 0; b < B; b++) {
-        for (int c = 0; c < C; c++) {
-          scalar_t inv_scale = exp(-params_a[c][0]);
-          for (int t = 0; t < T; t++) {
-            scalar_t input = input_a[b][c][t],
-                x = input * inv_scale + K,
-                output_grad = output_grad_a[b][c][t];
-            if (x < 0) x = 0;
-            else if (x >= N) x = N - 1;
-            // C++ rounds toward zero.
-            int n = (int) x;
-            // OK, at this point, 0 <= n < 2*K.
-            // backprop for:
-            // output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-            params_grad_a[c][n + 1] += output_grad * input;
-            y_vals_grad_a[c][n] += output_grad;
-            input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
-          }
-        }
-      }
-      // Now do the backprop for the loop above where we set y_vals_a.
-      for (int c = 0; c < C; c++) {
-        scalar_t scale = exp(params_a[c][0]),
-            scale_grad = 0.0,
-            sum_negative_grad = 0.0,
-            sum_positive_grad = 0.0;
-        for (int i = K - 1; i >= 0; i--) {
-          // Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-          scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
-          sum_negative_grad += y_grad_neg;
-          scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-          // Backprop for: sum_negative -= neg_scaled_param;
-          neg_scaled_param_grad -= sum_negative_grad;
-          // Backprop for: sum_positive += pos_scaled_param;
-          scalar_t pos_scaled_param_grad = sum_positive_grad;
-          // Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
-          pos_scaled_param_grad -= i * y_grad_pos;
-          sum_positive_grad += y_grad_pos;
-          // Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
-          // and: neg_scaled_param = params_a[c][K - i] * scale;
-          params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
-          params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
-          scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
-                         neg_scaled_param_grad * params_a[c][K - i]);
-        }
-        // Backprop for: scale = exp(params_a[c][0]),
-        params_grad_a[c][0] += scale * scale_grad;
-      }
-    }));
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+  std::cout << "p_grad = " << p_grad;
+  return std::vector<torch::Tensor>({px_grad, py_grad});
 }
...
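The `term1_deriv`/`term2_deriv` expressions in the CPU backward loop above use the derivative of log-add: if total = logaddexp(term1, term2), then d(total)/d(term1) = exp(term1 - total) and d(total)/d(term2) = 1 - exp(term1 - total). A quick standalone check of that identity with PyTorch autograd (illustrative only, not part of this file):

    import torch

    term1 = torch.tensor(0.3, requires_grad=True)
    term2 = torch.tensor(-0.7, requires_grad=True)
    total = torch.logaddexp(term1, term2)
    g1, g2 = torch.autograd.grad(total, (term1, term2))
    # These match the closed forms used in the kernel:
    assert torch.allclose(g1, torch.exp(term1 - total))
    assert torch.allclose(g2, 1.0 - torch.exp(term1 - total))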
torch_mutual_information/mutual_information_cuda_kernel.cu  (view file @ 3c1ec347)
...
@@ -73,7 +73,7 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
       p[b,0,0] = 0.0
       p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
-                         p[b,s,t-1] + py[b,s,t-1])
+                         p[b,s,t-1] + py[b,s,t-1])          (eq. 0)
               if s > 0 or t > 0,
               treating values with any -1 index as -infinity.
       .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
...
@@ -122,32 +122,33 @@ void mutual_information_kernel(
       num_t_blocks = T / BLOCK_SIZE + 1;

-    // num_blocks_this_iter is an upper bound on the number of blocks of size
-    // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration.  We go
-    // from the bottom left of the image so that on iter == 0 we process only one
-    // block with block-index (0, 0) then on iter == 1 we process block-indexes
-    // (1, 0) and (0, 1); and then on iter==2 we process (2, 0), (1, 1) and (0,
-    // 2); and so on.  We also will never have more than `num_s_blocks` blocks
-    // (We'll never have more than num_t_blocks either, but the numbering we use
-    // corresponds to s and not t, so if we hit the num_t_blocks limit, the
-    // lowest-numbered blocks on s would just not be active and we'll 'continue'
-    // below).
+    // num_blocks_this_iter is an upper bound on the number of blocks of size
+    // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
+    // These iterations start from the bottom left of the image so that on iter ==
+    // 0 we process only one block with block-index (0, 0) then on iter == 1 we
+    // process block-indexes (1, 0) and (0, 1); and then on iter==2 we process (2,
+    // 0), (1, 1) and (0, 2); and so on.  We also will never have more than
+    // `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
+    // the numbering we use corresponds to s and not t, so when we hit the
+    // num_t_blocks limit, the blocks with the lowest s indexes would just not be
+    // active and we'll 'continue' in the loop below).
   int num_blocks_this_iter = min(iter + 1, num_s_blocks);

   // For the block with s_block_begin == 0 and t_block_begin == 0 (for
   // easy illustration), px_buf[s][t] will contain exp(px[s - 1][t]); or 0
-  // for out-of-range indexes.
+  // for out-of-range indexes into px.
   // Likewise, py_buf[s][t] will contain exp(py[s][t - 1]).
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
       py_buf[BLOCK_SIZE][BLOCK_SIZE];

-  // 1st row/col of p_buf correspond to the previous blocks, or to an edge case.
-  // p_buf[s][t] == exp(p[s+s_block_begin-1][t+t_block_begin-1] - normalizer).
-  // So, again for this origin block, p_buf[s][t] corresponds to exp(p[s - 1][t
-  // - 1] - normalizer); or 0 for out-of-range values.
+  // 1st row/col of p_buf correspond to the previously computed blocks (lower
+  // `iter`), or to negative indexes into p.  So, for the origin block,
+  // p_buf[s][t] corresponds to exp(p[s - 1][t - 1] - normalizer); or 0 for
+  // out-of-range values.
   __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

   // boundary_buf will be used to store the b'th row of `boundary` if we have
-  // boundary information supplied.
+  // boundary information supplied; or (0, 0, S, T) otherwise.
   __shared__ int64_t boundary_buf[4];

   if (threadIdx.x == 0) {
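To make the anti-diagonal block scheduling in the comment above concrete, here is a tiny Python sketch of which (s_block, t_block) pairs a given outer iteration touches; the function name `blocks_for_iter` is hypothetical, used only for illustration:

    # Illustrative only: mirrors num_blocks_this_iter = min(iter + 1, num_s_blocks)
    # and (s_block, t_block) = (block, iter - block) from the kernel.
    def blocks_for_iter(it: int, num_s_blocks: int, num_t_blocks: int):
        out = []
        for block in range(min(it + 1, num_s_blocks)):
            s_block, t_block = block, it - block
            if t_block < num_t_blocks:  # blocks past the t limit are skipped ('continue')
                out.append((s_block, t_block))
        return out

    # it == 0 -> [(0, 0)]; it == 1 -> [(0, 1), (1, 0)]; it == 2 -> [(0, 2), (1, 1), (2, 0)]; ...

Each block on a given wavefront depends only on blocks from earlier wavefronts (below and to its left), which is what makes this schedule valid.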
...
@@ -157,69 +158,70 @@ void mutual_information_kernel(
     boundary_buf[3] = T;
   }

-  // batch_block_iter iterates over both batch elements (index b), and block
-  // indexes in the range [0..num_blocks_this_iter-1]
+  // batch_block_iter iterates over batch elements (index b) and block
+  // indexes in the range [0..num_blocks_this_iter-1], combining both
+  // batch and block indexes.
   for (int batch_block_iter = blockIdx.x;
        batch_block_iter < B * num_blocks_this_iter;
        batch_block_iter += gridDim.x) {
-    int b = batch_block_iter % B,
-        block = batch_block_iter / B;
+    int block = batch_block_iter / B,
+        b = batch_block_iter % B;  // b is the index into the batch

-    int s_block_begin = block * BLOCK_S_SIZE,
-        t_block_begin = (iter - block) * BLOCK_T_SIZE;
-    bool is_origin_block = (s_block_begin * t_block_begin == 0);
+    // Note: `block` can be no greater than `iter` because num_blocks_this_iter
+    // <= iter + 1, so iter - block >= 0.
+    int s_block_begin = block * BLOCK_SIZE,
+        t_block_begin = (iter - block) * BLOCK_SIZE;

-    if (threadDim.x < 4 && boundary.size(0) != 0)
+    if (boundary.size(0) != 0 && threadIdx.x < 4)
       boundary_buf[threadDim.x] = boundary[b][threadDim.x];
     __syncthreads();

     int s_begin = boundary_buf[0],
         t_begin = boundary_buf[1],
         s_end = boundary_buf[2],
         t_end = boundary_buf[3];
     s_block_begin += s_begin;
     t_block_begin += t_begin;

-    // block_S and block_T are the actual sizes of this block, no greater than
-    // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
-    // the end of the sequence.
-    // The last element of the output matrix p we write is (s_end, t_end),
-    // i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
+    // block_S and block_T are the actual sizes of this block (the block of `p`
+    // that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but
+    // possibly less than that if we are towards the end of the sequence.  The
+    // last element in the output matrix p that we need to write is (s_end,
+    // t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
     int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin),
         block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin);

     if (block_S <= 0 || block_T <= 0)
       continue;

+    bool is_origin_block = (s_block_begin * t_block_begin == 0);

-    // Load px_buf and py_buf. We exponentiate; the assumption is that they most likely
-    // won't overflow or underflow, but if they do overflow we'll detect it later; we'll
-    // also detect certain kinds of underflow.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+    // Load px_buf and py_buf.  We exponentiate; the assumption is that they
+    // most likely won't overflow or underflow, but if they do overflow we'll
+    // detect it later; we'll also detect certain kinds of underflow.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
       int s_in_block = i / BLOCK_SIZE,
           t_in_block = i % BLOCK_SIZE,
           s = s_in_block + s_block_begin,
           t = t_in_block + t_block_begin;

-      // the comparisons with S and T below just make sure we don't access
-      // out-of-memory regions; they do not guarantee we are in the range given
-      // by s_begin, s_end and so on. Note: comparing as unsigned int makes sure
-      // the index is nonnegative.
+      // comparing as unsigned int makes sure the index is nonnegative.
       scalar_t this_px = 0.0;
-      if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(S) &&
-          t <= T)
+      if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(s_end) &&
+          t <= t_end)
         this_px = exp(px[b][s - 1][t]);
       px_buf[s_in_block][t_in_block] = this_px;

       scalar_t this_py = 0.0;
-      if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(T) &&
-          s <= S)
+      if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(t_end) &&
+          s <= s_end)
         this_py = exp(py[b][s][t - 1]);
       py_buf[s_in_block][t_in_block] = this_py;
     }

-    // Load the 1st row and column of p_buf (except element[0][0] is not needed).
-    // Remember: p_buf[s][t] corresponds to exp(p[s + s_block_begin - 1][t + t_block_begin - 1] - normalizer.
+    // Load the 1st row and 1st column of p_buf (except element[0][0] is not
+    // needed).  This is the context from previously computed blocks of the
+    // image.  Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
+    // 1][t + t_block_begin - 1] - normalizer.
     if (threadIdx.x < 64) {  // 64 == warp size. First half of threads...
       if (threadIdx.x <= BLOCK_SIZE) {
         // s_in_p_buf are simply the indexes into p_buf
...
@@ -227,16 +229,14 @@ void mutual_information_kernel(
             t_in_p_buf = 0,
             s = s_in_p_buf + s_block_begin - 1,
             t = t_in_p_buf + t_block_begin - 1;

-        // The if-statement below just guards against out-of-range memory
-        // accesses, it does not guarantee that we really need these values.
         scalar_t this_p = -INFINITY;
-        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
-            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
-          this_p = p[s + s_block_begin][s + t_block_begin];
+        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
+            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
+          this_p = p[b][s][t];
         p_buf[threadIdx.x][0] = this_p;
       }
     } else {  // Another warp handles the other leg
-      if (threadIdx.x - 64 <= BLOCK_SIZE) {
+      if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
         int s_in_p_buf = 0,
             t_in_p_buf = threadIdx.x - 64,
             s = s_in_p_buf + s_block_begin - 1,
...
@@ -244,19 +244,19 @@ void mutual_information_kernel(
         // The if-statement below just guards against out-of-range memory
         // accesses, it does not guarantee that we really need these values.
         scalar_t this_p = -INFINITY;
-        if (static_cast<unsigned int>(s) < static_cast<unsigned int>(S) &&
-            static_cast<unsigned int>(t) < static_cast<unsigned int>(T))
-          this_p = p[s + s_block_begin][s + t_block_begin];
+        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
+            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
+          this_p = p[b][s][t];
         p_buf[threadIdx.x][0] = this_p;
       }
     }
     __syncthreads();

-    // We read p_buf in log-space; subtract 'normalizer', which mathematically
-    // could be any finite number, to get in a reasonable range of probabilities,
-    // and then exponentiate.  We'll do everything in non-log space, and later
-    // take a log before we write out the data.
+    // We read p_buf in log-space; we now subtract 'normalizer', which
+    // mathematically could be any finite number, to get it in a range close to
+    // zero, and then exponentiate.  We'll do everything in non-log space, for
+    // speed, and later take a log before we write out the data.
     scalar_t normalizer = (is_origin_block ? 0.0 :
                            max(px_buf[0][1], px_buf[1][0]));
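The 'normalizer' trick in the comments above is the usual log-sum-exp shift: subtracting any finite constant before exponentiating, running the recursion in linear space, and adding the constant back after the final log gives the same value. A small self-contained check, with `c` playing the role of `normalizer` (illustrative only):

    import torch

    a = torch.tensor(3.2)   # stands for p[s][t + 1] + px[s][t]
    b = torch.tensor(2.9)   # stands for p[s + 1][t] + py[s][t]
    c = torch.tensor(3.0)   # any finite 'normalizer'

    log_space = torch.logaddexp(a, b)                               # what the recursion means
    exp_space = torch.log(torch.exp(a - c) + torch.exp(b - c)) + c  # what the kernel computes
    assert torch.allclose(log_space, exp_space)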
...
@@ -265,50 +265,55 @@ void mutual_information_kernel(
     // and we'll overwrite with 1.0 if there is a panic situation due to
     // overflow.
     if (threadIdx.x <= BLOCK_SIZE) {
+      if (threadIdx.x == 0) {
-      // p_buf[0][0] is never used for its normal purpose; we set it to zero
+        // p_buf[0][0] is never used for its normal purpose; we set it to zero.
+        // p_buf[0][0] = 0.0; <-- for search purposes.
       // We'll later write an infinity there if something goes wrong, as a
       // 'panic' indicator.
       p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
                                exp(p_buf[threadIdx.x][0] - normalizer));
-    } else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
+    } else if ((int)threadIdx.x - 64 < BLOCK_SIZE) {
+      // this happens in a different warp so can be in parallel to the code above.
       p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
     }

-    if (threadIdx.x == 0) {
+    if (threadIdx.x == 0 && is_origin_block) {
       // This if-statement is an optimization and modification of the loop below
-      // for the value i == 0, i.e. inner-iteration == 0.  The modification
-      // is to use 0.0 if this is the "origin block" with s_block_begin == 0 and
-      // t_block_begin == 0.  This corresponds to the probability of the pair of
-      // sequences of length (0, 0).
-      p_buf[1][1] = (is_origin_block ? 0.0 :
-                     p_buf[0][1] * px_buf[0][0] +
-                     p_buf[1][0] * py_buf[0][0]);
+      // for the value i == 0, i.e. inner-iteration == 0.  The modification is
+      // to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
+      // i.e. s == s_begin, t == t_begin.  This corresponds to the
+      // probability of the pair of sequences of length (0, 0).
+      p_buf[1][1] = (is_origin_block ? 1.0 :
+                     p_buf[0][1] * px_buf[0][0] +
+                     p_buf[1][0] * py_buf[0][0]);
     }

-    scalar_t p_buf_s1_t;  // This is for an optimization.
-    if (i < BLOCK_SIZE) {
+    scalar_t p_buf_s1_t;  // This is for an optimization to avoid one
+                          // shared-memory read/write in the loop below.  it
+                          // represents p_buf[s + 1][t]; the first time we
+                          // access this, it will be for t == 0, except for
+                          // thread 0 when we first need it for t == 1.
+    if (threadIdx.x < BLOCK_SIZE) {
       int s = threadIdx.x;
-      p_buf_s1_t = p_buf[s + 1][0];
+      p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
     }

-    for (int i = 1; i < block_S + block_T; i++) {
+    for (int i = 1; i < block_S + block_T - 1; ++i) {
-      // i is the inner iteration, which corresponds to the (s + t) indexes of the
-      // elements within the block that we write. So i == 0 writes positions
-      // (s, t) == (0, 0); i == 1 writes (0, 1) and (1, 0); i == 2 writes
-      // (0, 2), (1, 1) and (2, 1); and so on.
-      // Note: not many threads participate in this part, only up to BLOCK_SIZE
-      // at most. Unfortunately we couldn't figure out a very meaningful way
-      // for more threads to do work, that looked like it would really spead
-      // things up.
+      // i is the inner iteration, which corresponds to the (s + t) indexes of
+      // the elements within the block that we write.  So i == 0 writes
+      // positions (s, t) == (0, 0) (but we treated i == 0 as a special case
+      // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
+      // and (2, 1); and so on.  Note: not many threads participate in this
+      // part, only up to BLOCK_SIZE at most.  Unfortunately we couldn't figure
+      // out a very meaningful way for more threads to do work, that looked like
+      // it would really spead things up.
       // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
       // but we do at least do the I/O in an efficient way and keep the
       // inner loop simple and fast (e.g. no exp() or log()).
       int s = threadIdx.x,
           t = i - s;
-      if (t >= 0) {
+      if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
         // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
         // row and column for context from previous blocks.  Taking into account
         // the way these buffers relate to the tensors p, px and py, and
...
@@ -320,7 +325,7 @@ void mutual_information_kernel(
         //
         // where you can see that apart from the offsets of tbb and sbb, this is
         // the same as the recursion defined for p in
-        // mutual_information.py:mutual_information_recursion().
+        // mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
 #if 1
         p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
                               p_buf[s + 1][t] * py_buf[s][t];
 #else
...
@@ -328,18 +333,30 @@ void mutual_information_kernel(
         // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
         // the need for a load from shared memory.
         p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
+        // The next time this thread reads p_buf_s1_t, t will be one greater,
+        // so p_buf_s1_t will contain p_buf[s + 1][t].  The first time this
+        // thread uses p_buf_s1_t is when t == 0, except for thread 0 where
+        // the 1st item accessed is for s == 0, t == 1.
         p_buf[s + 1][t + 1] = p_buf_s1_t;
 #endif
+        // We don't need to do __syncthreads() in this loop because all the
+        // threads that are active are in the same warp.  (However, in future,
+        // if NVidia changes some things, we might need to sync here).
       }
       __syncthreads();
     }

-    // Write out the data.
-    for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-      int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
-      if (s < block_S && t < block_T) {
-        float this_p = p_buf[s + 1][t + 1];
-        p[b][s + s_block_begin][t + t_block_begin] = normalizer + log(this_p);
+    // Write out the data to p; check that nothing has gone out of numerical
+    // range, and write 'panic' flag if it has.
+    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+      int s_in_block = i / BLOCK_SIZE,
+          t_in_block = i % BLOCK_SIZE,
+          s = s_in_block + s_block_begin,
+          t = t_in_block + t_block_begin;
+      if (s_in_block < block_S && t_in_block < block_T) {
+        float this_p = p_buf[s_in_block + 1][t_in_block + 1];
+        p[b][s][t] = normalizer + log(this_p);
+        // If this_p is infinity, NaN or zero...
         if (this_p - this_p != 0 || this_p == 0)
           p_buf[0][0] = 1.0;  // This is a "panic" flag.
       }
...
@@ -351,27 +368,31 @@ void mutual_information_kernel(
     // Write `ans`, if this is the final (top-right) block in its sequence
     // Logically, the following equation corresponds to:
     //   ans[b] = p[b][s_end][t_end]
-    if (s_block_begin + S > s_end && t_block_begin + T > t_end)
-      ans[b] = normalizer + log(p_buf[s_end - s_block_begin + 1][t_end - t_block_begin + 1]);
+    if (s_block_begin + block_S - 1 == s_end &&
+        t_block_begin + block_T - 1 == t_end) {
+      // you could read block_S below as block_S - 1 + 1, meaning,
+      // it's the last index in a block of size block_S, but the indexes into
+      // p_buf have a "+ 1".  Likewise for block_T.
+      ans[b] = normalizer + log(p_buf[block_S][block_T]);
+    }
     }
     }

     if (p_buf[0][0] != 0.0) {
-      // "panic" flag set.  We need to re-do the computation using log-add.
+      // The "panic" flag is set.  We need to re-do the computation using log-add.
       // This time we won't use the buffers, we'll just load and save from main
       // memory.  This code should very rarely be reached; and anyway, caching
       // should help us quite a bit.
-      for (int i = 0; i < 2 * BLOCK_SIZE; i++) {
-        int block_s = threadIdx.x,
-            block_t = i - block_s;
-        if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T) &&
-            block_s < block_S) {
-          int s = block_s + s_block_begin,
-              t = block_t + t_block_begin;
+      for (int i = 0; i < block_S + block_T - 1; ++i) {
+        int s_in_block = threadIdx.x,
+            t_in_block = i - block_s;
+        if (s_in_block < block_S &&
+            static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
+          int s = s_in_block + s_block_begin,
+              t = t_in_block + t_block_begin;
           float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
+              this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
               p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
-              this_px = px[b][s][t],
-              this_py = py[b][s][t];
+              this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
           float this_p = LogAdd(p_s1 + this_px,
                                 p_t1 + this_py);
           if (i == 0 && is_origin_block)
...
@@ -382,7 +403,8 @@ void mutual_information_kernel(
       if (threadIdx.x == 0) {
         // Write `ans`, if this is the final (top-right) block in its sequence.
         // This is only reached in the 'panic situation' where we had overflow.
-        if (s_block_begin + S > s_end && t_block_begin + T > t_end)
+        if (s_block_begin + block_S - 1 == s_end &&
+            t_block_begin + block_T - 1 == t_end)
           ans[b] = p[b][s_end][t_end];
       }
     }
...
@@ -402,17 +424,18 @@ void mutual_information_kernel(
               ep[b][s][t - 1] * epy[b][s][t - 1].      (eq. 1)

   (A)
-  First we consider the part that involves recursion, i.e. the part involving only gradients of
-  ep.  The backprop involving ep only would be:
+  First we consider the part of the backprop that requires recursion or iteration,
+  i.e. the part involving only gradients of ep.  This is:

      ep_grad[b][s - 1][t] += ep_grad[b][s][t] * epx[b][s - 1][t]
      ep_grad[b][s][t - 1] += ep_grad[b][s][t] * epy[b][s][t - 1].

   .. and if we add 1 to the s index of the first equation above and 1 to the
   t index of the second equation, we can see that:

      ep_grad[b][s][t] = ep_grad[b][s + 1][t] * epx[b][s][t] +
                         ep_grad[b][s][t + 1] * epy[b][s][t].

-  Now, if ep = exp(p), then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
+  Now, if ep = exp(p), and y is the loss function we are backprop'ing,
+  then ep_grad == dy/dep == dy/dp dp/dep == dy/dp / (dep/dp) == dy/dp / exp(p)
   == dy/dp / ep. == p_grad / ep.
   I.e. ep_grad = p_grad / ep.
   So we can write the above as:
...
@@ -425,8 +448,8 @@ void mutual_information_kernel(
   (B)  The following is the backprop for epx and epy from (eq. 1):

-     epx_grad[b][s - 1][t] +=   ep_grad[b][s][t] * ep[b][s - 1][t]
-     epy_grad[b][s][t - 1] +=   ep_grad[b][s][t] * ep[b][s][t - 1]
+     epx_grad[b][s - 1][t] += ep_grad[b][s][t] * ep[b][s - 1][t]
+     epy_grad[b][s][t - 1] += ep_grad[b][s][t] * ep[b][s][t - 1]

   .. adding 1 to the s indexes in the 1st equation and to the t indexes in the 2nd:
...
@@ -435,7 +458,7 @@ void mutual_information_kernel(
   Using, similar to the above, ep_grad = p_grad / ep, and similarly,
   epx_grad = px_grad / epx and epy_grad = py_grad / epy, and writing exp(p) for p and so on,
-  the above becomes
+  the above becomes:

    px_grad[b][s][t] / exp(px[b][s][t]) = p_grad[b][s + 1][t] / exp(p[b][s + 1][t]) * exp(p[b][s][t])
    py_grad[b][s][t] / exp(py[b][s][t]) = p_grad[b][s][t + 1] / exp(p[b][s][t + 1]) * exp(p[b][s][t])
...
@@ -450,11 +473,11 @@ void mutual_information_kernel(
      yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])    (eq. 5)

   .. and note that these quantities are <= 1 so there is no problem doing
-  the exponentiation.  So the recursion can be simplified as:
+  the exponentiation.  So the recursion can be simplified, from eqs. (2, 3a, 3b), as:

     p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
                       p_grad[b][s][t + 1] * yderiv[b][s][t]      (eq. 6)
-    px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t]      (eq. 7)
+    px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]      (eq. 7)
     py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t]      (eq. 8)

   (It might seem like we could just reuse px_grad and py_grad for (eq. 6), but it's
...
@@ -462,8 +485,9 @@ void mutual_information_kernel(
   write to shared memory within the loop that's the limiting factor.)

   The backward pass will be slightly different from the forward pass in terms of
-  how we store p (and p_grad), because for writing a particular block of p_grad, we
-  need context on the top and right instead of the bottom and left.
+  how we store and index p (and p_grad), because for writing a particular block
+  of p_grad, we need context on the top and right instead of the bottom and
+  left.  So there are offsets of 1.
 */
 template <typename scalar_t>
 __global__
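The step "ep_grad == dy/dep == ... == p_grad / ep" in the derivation above can be sanity-checked in isolation with autograd; the loss function y below is an arbitrary smooth function of ep, chosen only for illustration:

    import torch

    p = torch.tensor(0.4, requires_grad=True)
    ep = p.exp()
    ep.retain_grad()            # so we can read dy/d(ep) directly
    y = (3.0 * ep + 1.0).log()  # any smooth function of ep
    y.backward()
    # ep_grad == p_grad / ep, i.e. dy/d(ep) == (dy/dp) / exp(p)
    assert torch.allclose(ep.grad, p.grad / p.exp())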
...
@@ -472,8 +496,6 @@ void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> py,       // B, S + 1, T.
     torch::PackedTensorAccessor32<scalar_t, 3> p,        // B, S + 1, T + 1.  Produced in forward pass.
     torch::PackedTensorAccessor32<scalar_t, 1> ans_grad, // [B].  This is an input.
-    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad_compare, // [B].  A value will be written to here which
-                                                                 // should ideally equal ans_grad.
     torch::PackedTensorAccessor32<scalar_t, 3> p_grad,   // B, S + 1, T + 1.  This is a temporary.
     torch::PackedTensorAccessor32<scalar_t, 3> px_grad,  // B, S, T + 1.
     torch::PackedTensorAccessor32<scalar_t, 3> py_grad,  // B, S + 1, T.
...
@@ -483,16 +505,18 @@ void mutual_information_backward_kernel(
     // be any sufficiently large number but will actually be:
     // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
     // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
-    bool overwrite_ans_grad) {   // If true, overwrite ans_grad with a value
-                                 // which, if everything is working correctly,
-                                 // should be identical or very close to the
-                                 // value of ans_grad that was passed in.
+    bool overwrite_ans_grad) {   // If overwite_ans_grad == true, this function
+                                 // will overwrite ans_grad with a value which,
+                                 // if everything is working correctly, should be
+                                 // identical or very close to the value of
+                                 // ans_grad that was passed in.
   const int B = px.size(0),
       S = px.size(1),
       T = py.size(2);

-  // For statements that are the same as the forward pass, we are omitting some comments
-  // what we made there. We'll focus, in the comments, on differences from the forward pass.
+  // For statements that are the same as the forward pass, we are omitting some
+  // comments.  We'll focus, in the comments, on differences from the forward
+  // pass.
   const int num_s_blocks = S / BLOCK_SIZE + 1,
       num_t_blocks = T / BLOCK_SIZE + 1,
       num_blocks_this_iter = min(iter + 1, num_s_blocks);
...
@@ -502,29 +526,33 @@ void mutual_information_backward_kernel(
   // but then modified to store the "xderiv" and "yderiv" values defined
   // in (eq. 5) and (eq. 6) above.  For out-of-range values, we'll write 0.0
   // here.
-  // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
-  // py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Initially (before xderiv/yderiv are written):
+  //   px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
+  //   py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
+  // Later (see eq. 4 and eq. 5):
+  //   px_buf[s][t] contains exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt]),
+  //   py_buf[s][t] contains exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]
+  //   where ss == s + s_block_begin, tt = t + t_block_begin.
   // Unlike in the forward code, there is no offset of 1 in the indexes.
   __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
       py_buf[BLOCK_SIZE][BLOCK_SIZE];

   // p_buf is initially used to store p, and then (after we are done putting
   // xderiv and yderiv into px_buf and py_buf) it is repurposed to store
   // p_grad.
   //
   // Unlike in the forward pass, p_buf has the same numbering as px_buf and
-  // py_buf not offset by 1: e.g., for the origin block, p_buf[0][0] refers
-  // to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
+  // py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0]
+  // refers to p[0][0] and not p[-1][-1].  The p_buf block is larger by 1 than
   // the block for px_buf and py_buf; unlike in the forward pass, we store
-  // context on the top right, not the bottom left, i.e. the elements at
+  // context on the top and right, not the bottom and left, i.e. the elements at
   // (one past the largest indexes in the block).
   //
   // For out-of-range elements of p_buf, we'll put zero.
   __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];

   // boundary_buf will be used to store the b'th row of `boundary` if we have
-  // boundary information supplied.
+  // boundary information supplied; or (0, 0, S, T) if not.
   __shared__ int64_t boundary_buf[4];

   if (threadIdx.x == 0) {
...
@@ -541,13 +569,13 @@ void mutual_information_backward_kernel(
   for (int batch_block_iter = blockIdx.x;
        batch_block_iter < B * num_blocks_this_iter;
        batch_block_iter += gridDim.x) {
-    int b = batch_block_iter % B,
-        block = batch_block_iter / B;
+    int block = batch_block_iter / B,
+        b = batch_block_iter % B;

-    int s_block_begin = block * BLOCK_S_SIZE,
-        t_block_begin = (iter - block) * BLOCK_T_SIZE;
+    int s_block_begin = block * BLOCK_SIZE,
+        t_block_begin = (iter - block) * BLOCK_SIZE;

-    if (threadDim.x < 4 && boundary.size(0) != 0)
-      boundary_buf[threadDim.x] = boundary[b][threadDim.x];
+    if (threadIdx.x < 4 && boundary.size(0) != 0)
+      boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
     __syncthreads();

     int s_begin = boundary_buf[0],
@@ -560,68 +588,69 @@ void mutual_information_backward_kernel(
...
@@ -560,68 +588,69 @@ void mutual_information_backward_kernel(
// block_S and block_T are the actual sizes of this block, no greater than
// block_S and block_T are the actual sizes of this block, no greater than
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards
// the end of the sequence.
// the end of the sequence.
// The last element of the output matrix p we write is (s_end, t_end),
// The last element of the output matrix p
_grad
we write is (s_end, t_end),
// i.e. the one-past-the-end index is (s_end + 1, t_end + 1).
// i.e. the one-past-the-end index
of p_grad
is (s_end + 1, t_end + 1).
int
block_S
=
min
(
BLOCK_SIZE
,
s_end
+
1
-
s_block_begin
),
int
block_S
=
min
(
BLOCK_SIZE
,
s_end
+
1
-
s_block_begin
),
block_T
=
min
(
BLOCK_SIZE
,
t_end
+
1
-
t_block_begin
);
block_T
=
min
(
BLOCK_SIZE
,
t_end
+
1
-
t_block_begin
);
if
(
block_S
<=
0
||
block_T
<=
0
)
if
(
block_S
<=
0
||
block_T
<=
0
)
continue
;
continue
;
// Load px_buf and py_buf. At this point
they
just
contain
px and py
// Load px_buf and py_buf. At this point
we
just
set them to the
px and py
// for this block.
// for this block.
for
(
int
i
=
thread
Dim
.
x
;
i
<
BLOCK_SIZE
*
BLOCK_SIZE
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
thread
Idx
.
x
;
i
<
BLOCK_SIZE
*
BLOCK_SIZE
;
i
+=
blockDim
.
x
)
{
int
s_in_block
=
i
/
BLOCK_SIZE
,
int
s_in_block
=
i
/
BLOCK_SIZE
,
t_in_block
=
i
%
BLOCK_SIZE
,
t_in_block
=
i
%
BLOCK_SIZE
,
s
=
s_in_block
+
s_block_begin
,
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
t
=
t_in_block
+
t_block_begin
;
// We let p
s
and py default to -infinity if they are out of range, which will
// We let p
x
and py default to -infinity if they are out of range, which will
// cause xderiv and yderiv for out-of-range values to be zero, and cause
// cause xderiv and yderiv for out-of-range values to be zero, and cause
// correct behavior in edge cases (for the top and right blocks).
// correct behavior in edge cases (for the top and right blocks).
// The issue is that p and p_grad are of larger size than px and py.
// The issue is that p and p_grad are of larger size than px and py.
scalar_t
this_px
=
-
INFINITY
;
scalar_t
this_px
=
-
INFINITY
;
if
(
s
<
s_end
&&
t
<=
t_end
)
if
(
s
<
s_end
&&
t
<=
t_end
)
this_px
=
px
[
b
][
s
-
1
][
t
];
this_px
=
px
[
b
][
s
][
t
];
px_buf
[
s_in_block
][
t_in_block
]
=
this_px
;
px_buf
[
s_in_block
][
t_in_block
]
=
this_px
;
scalar_t
this_py
=
-
INFINITY
;
scalar_t
this_py
=
-
INFINITY
;
if
(
s
<=
s_end
&&
t
<
t_end
)
if
(
s
<=
s_end
&&
t
<
t_end
)
this_py
=
py
[
b
][
s
][
t
-
1
];
this_py
=
py
[
b
][
s
][
t
];
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
}
}
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// load p. This time we loop over the exact indexes we need. Above
// reads more aligned.
// we looped to BLOCK_SIZE * BLOCK_SIZE rather than block_S and block_T
for
(
int
i
=
threadIdx
.
x
;
i
<
(
BLOCK_SIZE
+
1
)
*
(
BLOCK_SIZE
+
1
);
i
+=
blockDim
.
x
)
{
// because having power-of-2 arrangement of threads may be helpful
int
s_in_block
=
i
/
(
BLOCK_SIZE
+
1
),
// for aligned reads, but here the loop is up to (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1)
t_in_block
=
i
%
(
BLOCK_SIZE
+
1
),
// which is not a power of 2, so that is not a concern here.
for
(
int
i
=
threadDim
.
x
;
i
<
(
BLOCK_SIZE
+
1
)
*
(
BLOCK_SIZE
+
1
);
i
+=
blockDim
.
x
)
{
int
s_in_block
=
i
/
(
BLOCK_SIZE
+
1
),
// 0 <= s_in_block <= block_S
t_in_block
=
i
%
(
BLOCK_SIZE
+
1
),
// 0 <= t_in_block <= block_T
s
=
s_in_block
+
s_block_begin
,
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
t
=
t_in_block
+
t_block_begin
;
// Setting 0.0 for out-of-bounds elements, together with setting
// Setting 0.0 for out-of-bounds elements, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases,
// ensure that we do the right thing in top and right edge cases,
// i.e. that no derivatives will be propagated from out-of-bounds points.
// i.e. that no derivatives will be propagated from out-of-bounds points
p_buf
[
s_in_block
][
t_in_block
]
=
(
s
<=
s_end
&&
t
<=
t_end
?
// because the corresponding xderiv and yderiv values will be zero.
p
[
b
][
s
][
t
]
:
0.0
);
scalar_t
this_p
=
0.0
;
if
(
s
<=
s_end
&&
t
<=
t_end
)
this_p
=
p
[
b
][
s
][
t
];
p_buf
[
s_in_block
][
t_in_block
]
=
this_p
;
}
}
  // Set xderiv and yderiv; see (eq. 4) and (eq. 5).
- for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+ for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
    // We can apply this formula to the entire block even if we are processing
-   // a partial block; elements outside the partial block will not be used so
-   // their values don't matter, and elements just out
-   int t = i % BLOCK_SIZE, s = i / BLOCK_SIZE;
+   // a partial block; we have ensured that x_buf and y_buf contain -infinity,
+   // and p contains 0, for out-of-range elements, so we'll get x_buf and y_buf
+   // containing 0 after applying the following formulas.
+   int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
    // Mathematically the following is doing:
    //   xderiv[b][s][t] := exp(p[b][s][t] + px[b][s][t] - p[b][s + 1][t])
    //   (with an offset on the s and t indexes)
-   px_buf[s][t] = exp(px_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
+   px_buf[s][t] = exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t]);
    // Mathematically the following is doing:
    //   yderiv[b][s][t] := exp(p[b][s][t] + py[b][s][t] - p[b][s][t + 1])
    //   (with an offset on the s and t indexes)
-   py_buf[s][t] = exp(px_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
+   py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
  }
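  // Editorial note: for reference, a hedged single-sequence CPU sketch of what
  // (eq. 4) and (eq. 5) compute. All names here (ref_derivs, S, T) are
  // illustrative and not part of this commit or the library's API.
  //
  //   #include <cmath>
  //   #include <vector>
  //
  //   void ref_derivs(int S, int T,
  //                   const std::vector<std::vector<double>> &p,    // (S+1) x (T+1)
  //                   const std::vector<std::vector<double>> &px,   // S x (T+1)
  //                   const std::vector<std::vector<double>> &py,   // (S+1) x T
  //                   std::vector<std::vector<double>> &xderiv,     // S x (T+1)
  //                   std::vector<std::vector<double>> &yderiv) {   // (S+1) x T
  //     for (int s = 0; s < S; s++)
  //       for (int t = 0; t <= T; t++)
  //         xderiv[s][t] = std::exp(p[s][t] + px[s][t] - p[s + 1][t]);
  //     for (int s = 0; s <= S; s++)
  //       for (int t = 0; t < T; t++)
  //         yderiv[s][t] = std::exp(p[s][t] + py[s][t] - p[s][t + 1]);
  //   }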
  // Load p_grad for the top and right elements in p_buf: i.e. for elements
...
@@ -630,7 +659,8 @@ void mutual_information_backward_kernel(
  // never be accessed.
  // These are the p_grad values computed by previous instances of this kernel
  // If this is one of the top or right blocks, some or all of the p_grad
- // values we'd be reading here will be out of range, and we use zeros.
+ // values we'd be reading here will be out of range, and we use zeros
+ // to ensure no gradient gets propagated from those positions.
  if (threadIdx.x < block_S) {
    int s_in_block = threadIdx.x,
        t_in_block = block_T,
...
@@ -638,34 +668,33 @@ void mutual_information_backward_kernel(
        t = t_in_block + t_block_begin;
    p_buf[s_in_block][t_in_block] = (
        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
- } else if (static_cast<unsigned int>(threadIdx.x - 64) <
+ } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
             static_cast<unsigned int>(block_T)) {
+   // casting to unsigned before the comparison tests for both negative and
+   // out-of-range values of (int)threadIdx.x - 64.
    int s_in_block = block_S,
-       t_in_block = threadIdx.x - 64,
+       t_in_block = (int)threadIdx.x - 64,
        s = s_in_block + s_block_begin,
        t = t_in_block + t_block_begin;
    p_buf[s_in_block][t_in_block] = (
        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
  }
- // The number of inner iterations, i.e. iterations inside this
- // kernel, is this_num_inner_iters.  The highest iteration,
- // corresponding to the highest-indexed value of p_buf that
- // we need to set,
- // corresponds to p_buf[block_S - 1][block_T - 1],
- // and the iteration number is the sum of these indexes, i.e.
- // (block_S - 1) + (block_T - 1).
+ // The highest-numbered value in p_buf that we need (corresponding,
+ // of course, to p_grad), is:
+ //   p_buf[block_S - 1][block_T - 1],
+ // and the inner iteration number (i) on which we set this is the sum of
+ // these indexes, i.e. (block_S - 1) + (block_T - 1).
  bool is_final_block = (s_block_begin + block_S == s_end + 1 &&
                         t_block_begin + block_T == t_end + 1);
  int first_iter = block_S + block_T - 2;
  if (is_final_block) {
-   // The following statement, mathematically, corresponds to:
-   //   p_grad[b][s_end][t_end] = ans_grad[b].  Normally this element of p_buf
-   // would be set by the first iteration of the loop below, so if it's set
-   // this way we have to decrement first_iter to prevent it being
-   // overwritten.
+   // The following statement corresponds to:
+   //   p_grad[b][s_end][t_end] = ans_grad[b]
+   // Normally this element of p_buf would be set by the first iteration of
+   // the loop below, so if it's set this way we have to decrement first_iter
+   // to prevent it from being overwritten.
    p_buf[block_S - 1][block_T - 1] = ans_grad[b];
    --first_iter;
  }
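  // Editorial note: a small standalone sketch (illustrative only, not kernel code)
  // of the anti-diagonal order that first_iter sets up. On inner iteration i every
  // (s, t) with s + t == i can be updated safely, because p_grad[s][t] depends only
  // on p_grad[s + 1][t] and p_grad[s][t + 1], which lie on the already-processed
  // diagonal i + 1 (iterations run from first_iter down to 0).
  //
  //   #include <cstdio>
  //   int main() {
  //     const int block_S = 3, block_T = 2;
  //     int first_iter = block_S + block_T - 2;   // matches the kernel's first_iter
  //     for (int i = first_iter; i >= 0; --i)
  //       for (int s = 0; s < block_S; ++s) {
  //         int t = i - s;
  //         if (t >= 0 && t < block_T)
  //           std::printf("iter %d: update (s=%d, t=%d)\n", i, s, t);
  //       }
  //     return 0;
  //   }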
...
@@ -675,7 +704,8 @@ void mutual_information_backward_kernel(
        t = i - threadIdx.x;
    if (t >= 0) {
      // The following statement is really operating on the gradients;
-     // it corresponds to (eq. 6) defined above, i.e.:
+     // it corresponds, with offsets of s_block_begin and t_block_begin
+     // on the indexes, to (eq. 6) defined above, i.e.:
      //   p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
      //                     p_grad[b][s][t + 1] * yderiv[b][s][t]
      p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
...
@@ -684,17 +714,19 @@ void mutual_information_backward_kernel(
  }
  // Write out p_grad, px_grad and py_grad.
- for (int i = threadDim.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
-   int t_in_block = i % BLOCK_SIZE,
-       s_in_block = i / BLOCK_SIZE,
+ for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
+   int s_in_block = i / BLOCK_SIZE,
+       t_in_block = i % BLOCK_SIZE,
        s = s_in_block + s_block_begin,
        t = t_in_block + t_block_begin;
+   // s_end and t_end are the one-past-the-end of the (x,y) sequences, but
+   // the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
    if (t <= t_end && s <= s_end) {
      p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
      if (s < s_end) {  // write px_grad, which is of shape [B][S][T + 1]
        // From (eq. 7):
-       //   px_grad[b][s][t] = p_grad[b][s + 1][t] * yderiv[b][s][t]
+       //   px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t]
        px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] *
                            px_buf[s_in_block][t_in_block]);
      }
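      // Editorial note: a hedged single-sequence CPU reference (names illustrative,
      // not the library's API) of the recursion (eq. 6) and the px_grad write-out
      // (eq. 7) that this blocked kernel implements; out-of-range terms are treated
      // as zero, matching the kernel's zero padding.
      //
      //   #include <vector>
      //
      //   void ref_backward(int S, int T, double ans_grad,
      //                     const std::vector<std::vector<double>> &xderiv,  // S x (T+1)
      //                     const std::vector<std::vector<double>> &yderiv,  // (S+1) x T
      //                     std::vector<std::vector<double>> &p_grad,        // (S+1) x (T+1)
      //                     std::vector<std::vector<double>> &px_grad) {     // S x (T+1)
      //     p_grad[S][T] = ans_grad;  // corresponds to p_grad[b][s_end][t_end] = ans_grad[b]
      //     for (int s = S; s >= 0; --s)
      //       for (int t = T; t >= 0; --t) {
      //         if (s == S && t == T) continue;
      //         double from_x = (s < S) ? p_grad[s + 1][t] * xderiv[s][t] : 0.0;
      //         double from_y = (t < T) ? p_grad[s][t + 1] * yderiv[s][t] : 0.0;
      //         p_grad[s][t] = from_x + from_y;
      //       }
      //     for (int s = 0; s < S; ++s)
      //       for (int t = 0; t <= T; ++t)
      //         px_grad[s][t] = p_grad[s + 1][t] * xderiv[s][t];
      //   }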
...
@@ -741,7 +773,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
  // (however, num_threads may not be less than 128).
  int num_threads = 128,
-     num_blocks = 128,
+     num_blocks = 256,
      BLOCK_SIZE = 32;
  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
...
@@ -802,14 +834,18 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
  TORCH_CHECK(ans_grad.size(0) == b);
+ bool has_boundary = (bool)optional_boundary;
  torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
-     px_grad = torch::empty({B, S, T + 1}, opts),
-     py_grad = torch::empty({B, S + 1, T}, opts),
+     px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                torch::empty({B, S, T + 1}, opts)),
+     py_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                torch::empty({B, S + 1, T}, opts));
  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
  // (however, num_threads may not be less than 128).
  const int num_threads = 128,
-     num_blocks = 128,
+     num_blocks = 256,
      BLOCK_SIZE = 32;
  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
...
@@ -819,7 +855,7 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
      num_t_blocks = T / BLOCK_SIZE + 1,
      num_iters = num_s_blocks + num_t_blocks - 1;
- if ((bool)optional_boundary)
+ if (has_boundary)
    TORCH_CHECK(optional_boundary.value().device().is_cuda(),
                "boundary information must be in CUDA tensor");
  else
...
@@ -838,5 +874,6 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
                          iter,
                          overwrite_ans_grad);
  }
+ std::cout << "p_grad = " << p_grad;
  return std::vector<torch::Tensor>({px_grad, py_grad});
}
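// Editorial note: a rough worked example of the launch schedule implied above,
// assuming num_s_blocks is computed analogously to the num_t_blocks formula shown
// earlier (that line is not visible in this excerpt). One kernel launch is needed
// per anti-diagonal of the grid of BLOCK_SIZE x BLOCK_SIZE tiles covering the
// (S+1) x (T+1) 'p' matrix.
//
//   #include <cstdio>
//   int main() {
//     const int BLOCK_SIZE = 32;
//     int S = 100, T = 200;
//     int num_s_blocks = S / BLOCK_SIZE + 1;            // 4
//     int num_t_blocks = T / BLOCK_SIZE + 1;            // 7
//     int num_iters = num_s_blocks + num_t_blocks - 1;  // 10 kernel launches
//     std::printf("num_iters = %d\n", num_iters);
//     return 0;
//   }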