Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
52ae49ee
Commit
52ae49ee
authored
Jul 29, 2021
by
Daniel Povey
Browse files
Fix many bugs
parent
3c1ec347
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
85 deletions
+83
-85
torch_mutual_information/mutual_information_cpu.cpp
torch_mutual_information/mutual_information_cpu.cpp
+22
-27
torch_mutual_information/mutual_information_cuda.cpp
torch_mutual_information/mutual_information_cuda.cpp
+5
-5
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+56
-53
No files found.
torch_mutual_information/mutual_information_cpu.cpp
View file @
52ae49ee
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
inline
double
Exp
(
double
x
)
{
inline
double
Exp
(
double
x
)
{
return
exp
(
x
);
return
exp
(
x
);
}
}
...
@@ -52,13 +51,15 @@ inline float LogAdd(float x, float y) {
...
@@ -52,13 +51,15 @@ inline float LogAdd(float x, float y) {
// forward of mutual_information. See """... """ comment of `mutual_information` in
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
// mutual_information.py for documentation of the behavior of this function.
// px: of shape [B, S, T+1] where
torch
::
Tensor
mutual_information_cpu
(
torch
::
Tensor
px
,
torch
::
Tensor
mutual_information_cpu
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
optional_
boundary
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
)
{
torch
::
Tensor
p
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
boundary
.
dim
()
==
2
,
"boundary must be 2-dimensional."
);
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
(),
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
(),
"inputs must be CPU tensors"
);
"inputs must be CPU tensors"
);
...
@@ -70,26 +71,24 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
...
@@ -70,26 +71,24 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
T
=
px
.
size
(
2
)
-
1
;
T
=
px
.
size
(
2
)
-
1
;
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cpu
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
torch
::
Tensor
ans
=
torch
::
empty
({
B
},
opts
);
torch
::
Tensor
ans
=
torch
::
empty
({
B
},
opts
);
auto
long_opts
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
px
.
device
());
bool
has_boundary
=
(
boundary
.
size
(
0
)
!=
0
);
bool
has_boundary
=
(
bool
)
optional_boundary
;
if
(
!
has_boundary
)
optional_boundary
=
torch
::
empty
({
0
,
0
},
long_opts
);
TORCH_CHECK
(
optional_boundary
.
value
().
device
().
is_cpu
()
&&
optional_boundary
.
value
().
dtype
==
torch
::
kInt64
);
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_loop"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_loop"
,
([
&
]
{
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_a
=
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_a
=
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
();
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
();
auto
boundary_a
=
optional_boundary
.
value
()
.
packed_accessor32
<
int64_t
,
2
>
();
auto
boundary_a
=
boundary
.
packed_accessor32
<
int64_t
,
2
>
();
auto
ans_a
=
ans
.
packed_accessor32
<
scalar_t
,
1
>
();
auto
ans_a
=
ans
.
packed_accessor32
<
scalar_t
,
1
>
();
for
(
int
b
=
0
b
<
B
;
b
++
)
{
for
(
int
b
=
0
;
b
<
B
;
b
++
)
{
int
s_begin
,
s_end
,
t_begin
,
t_end
;
int
s_begin
,
s_end
,
t_begin
,
t_end
;
if
(
has_boundary
)
{
if
(
has_boundary
)
{
s_begin
=
boundary_a
[
b
][
0
];
s_begin
=
boundary_a
[
b
][
0
];
...
@@ -130,16 +129,17 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
...
@@ -130,16 +129,17 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cpu
(
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cpu
(
torch
::
Tensor
px
,
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
optional_
boundary
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
,
torch
::
Tensor
p
,
torch
::
Tensor
ans_grad
)
{
torch
::
Tensor
ans_grad
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
boundary
.
dim
()
==
2
,
"boundary must be 2-dimensional."
);
TORCH_CHECK
(
ans_grad
.
dim
()
==
1
,
"ans_grad must be 3-dimensional."
);
TORCH_CHECK
(
ans_grad
.
dim
()
==
1
,
"ans_grad must be 3-dimensional."
);
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
()
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
()
&&
ans_grad
.
device
()
==
cpu
(),
&&
ans_grad
.
device
()
.
is_
cpu
(),
"inputs must be CPU tensors"
);
"inputs must be CPU tensors"
);
auto
scalar_t
=
px
.
scalar_type
();
auto
scalar_t
=
px
.
scalar_type
();
...
@@ -150,8 +150,12 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
...
@@ -150,8 +150,12 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
T
=
px
.
size
(
2
)
-
1
;
T
=
px
.
size
(
2
)
-
1
;
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cpu
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
bool
has_boundary
=
(
bo
ol
)
optional_boundary
;
bool
has_boundary
=
(
bo
undary
.
size
(
0
)
!=
0
)
;
torch
::
Tensor
p_grad
=
torch
::
zeros
({
B
,
S
+
1
,
T
+
1
},
opts
),
torch
::
Tensor
p_grad
=
torch
::
zeros
({
B
,
S
+
1
,
T
+
1
},
opts
),
px_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
,
T
+
1
},
opts
)
:
px_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
,
T
+
1
},
opts
)
:
...
@@ -159,27 +163,18 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
...
@@ -159,27 +163,18 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
py_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
+
1
,
T
},
opts
)
:
py_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
+
1
,
T
},
opts
)
:
torch
::
empty
({
B
,
S
+
1
,
T
},
opts
));
torch
::
empty
({
B
,
S
+
1
,
T
},
opts
));
auto
long_opts
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
px
.
device
());
if
(
!
has_boundary
)
optional_boundary
=
torch
::
empty
({
0
,
0
},
long_opts
);
TORCH_CHECK
(
optional_boundary
.
value
().
device
().
is_cpu
()
&&
optional_boundary
.
value
().
dtype
==
torch
::
kInt64
);
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_backward_loop"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_backward_loop"
,
([
&
]
{
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_a
=
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
//
py_a = py.packed_accessor32<scalar_t, 3>(),
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_grad_a
=
p_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_grad_a
=
p_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
px_grad_a
=
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
px_grad_a
=
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_grad_a
=
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
();
py_grad_a
=
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
();
auto
ans_grad_a
=
ans_grad
.
packed_accessor32
<
scalar_t
,
1
>
();
auto
ans_grad_a
=
ans_grad
.
packed_accessor32
<
scalar_t
,
1
>
();
auto
boundary_a
=
boundary
.
packed_accessor32
<
int64_t
,
2
>
();
auto
boundary_a
=
optional_boundary
.
value
().
packed_accessor32
<
int64_t
,
2
>
();
for
(
int
b
=
0
;
b
<
B
;
b
++
)
{
for
(
int
b
=
0
b
<
B
;
b
++
)
{
int
s_begin
,
s_end
,
t_begin
,
t_end
;
int
s_begin
,
s_end
,
t_begin
,
t_end
;
if
(
has_boundary
)
{
if
(
has_boundary
)
{
s_begin
=
boundary_a
[
b
][
0
];
s_begin
=
boundary_a
[
b
][
0
];
...
...
torch_mutual_information/mutual_information_cuda.cpp
View file @
52ae49ee
#include <torch/extension.h>
#include <torch/extension.h>
/*
/*
Forward of mutual_information. See also """... """ comment of
Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. This It is the core recursion
`mutual_information` in mutual_information.py. This It is the core recursion
...
@@ -33,8 +32,9 @@
...
@@ -33,8 +32,9 @@
contains, where for each batch element b, boundary[b] equals
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
x and y sequences that we should process. Alternatively, may be
default to (0, 0, S, T); and they should not exceed these bounds.
a tensor of shape [0][0] and type int64_t; the elements will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
with s_end and t_end being (S, T) if `boundary` was specified,
...
@@ -48,7 +48,7 @@
...
@@ -48,7 +48,7 @@
*/
*/
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
// [B][S][T+1]
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
// [B][S][T+1]
torch
::
Tensor
py
,
// [B][S+1][T]
torch
::
Tensor
py
,
// [B][S+1][T]
std
::
optional
<
torch
::
Tensor
>
boundary
_info
,
// [B][4], int64_t.
torch
::
Tensor
boundary
,
// [B][4], int64_t.
torch
::
Tensor
p
);
// [B][S+1][T+1]; an output
torch
::
Tensor
p
);
// [B][S+1][T+1]; an output
...
@@ -63,7 +63,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
...
@@ -63,7 +63,7 @@ torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cuda
(
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cuda
(
torch
::
Tensor
px
,
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
boundary
_info
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
,
torch
::
Tensor
p
,
torch
::
Tensor
ans_grad
,
torch
::
Tensor
ans_grad
,
bool
overwrite_ans_grad
);
bool
overwrite_ans_grad
);
...
...
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
52ae49ee
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
#include <cmath> // for INFINITY
#include <cmath> // for INFINITY
// returns log(exp(x) + exp(y)).
// returns log(exp(x) + exp(y)).
__forceinline__
__device__
double
LogAdd
(
double
x
,
double
y
)
{
__forceinline__
__device__
double
LogAdd
(
double
x
,
double
y
)
{
double
diff
;
double
diff
;
...
@@ -22,7 +21,7 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
...
@@ -22,7 +21,7 @@ __forceinline__ __device__ double LogAdd(double x, double y) {
}
}
// returns log(exp(x) + exp(y)).
// returns log(exp(x) + exp(y)).
__forceinline__
__device__
inline
float
LogAdd
(
float
x
,
float
y
)
{
__forceinline__
__device__
float
LogAdd
(
float
x
,
float
y
)
{
float
diff
;
float
diff
;
if
(
x
<
y
)
{
if
(
x
<
y
)
{
diff
=
x
-
y
;
diff
=
x
-
y
;
...
@@ -81,8 +80,9 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
...
@@ -81,8 +80,9 @@ __forceinline__ __device__ inline float LogAdd(float x, float y) {
contains, where for each batch element b, boundary[b] equals
contains, where for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. If not set, these
x and y sequences that we should process. Otherwise, must be
default to (0, 0, S, T); and they should not exceed these bounds.
a tensor of shape [0][0] of type int64_t; the values will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was specified,
with s_end and t_end being (S, T) if `boundary` was specified,
...
@@ -118,8 +118,8 @@ void mutual_information_kernel(
...
@@ -118,8 +118,8 @@ void mutual_information_kernel(
// You can read the following expressions as simplifications of, for example,
// You can read the following expressions as simplifications of, for example,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
const
int
num_s_blocks
=
S
/
BLOCK_SIZE
+
1
,
const
int
num_s_blocks
=
S
/
BLOCK_SIZE
+
1
;
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
;
//,
num_t_blocks = T / BLOCK_SIZE + 1;
// num_blocks_this_iter is an upper bound on the number of blocks of size
// num_blocks_this_iter is an upper bound on the number of blocks of size
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
// (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`).
...
@@ -174,7 +174,7 @@ void mutual_information_kernel(
...
@@ -174,7 +174,7 @@ void mutual_information_kernel(
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
if
(
boundary
.
size
(
0
)
!=
0
&&
threadIdx
.
x
<
4
)
if
(
boundary
.
size
(
0
)
!=
0
&&
threadIdx
.
x
<
4
)
boundary_buf
[
thread
Dim
.
x
]
=
boundary
[
b
][
thread
Dim
.
x
];
boundary_buf
[
thread
Idx
.
x
]
=
boundary
[
b
][
thread
Idx
.
x
];
__syncthreads
();
__syncthreads
();
int
s_begin
=
boundary_buf
[
0
],
int
s_begin
=
boundary_buf
[
0
],
t_begin
=
boundary_buf
[
1
],
t_begin
=
boundary_buf
[
1
],
...
@@ -384,7 +384,7 @@ void mutual_information_kernel(
...
@@ -384,7 +384,7 @@ void mutual_information_kernel(
// should help us quite a bit.
// should help us quite a bit.
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
int
s_in_block
=
threadIdx
.
x
,
int
s_in_block
=
threadIdx
.
x
,
t_in_block
=
i
-
block
_s
;
t_in_block
=
i
-
s_in_
block
;
if
(
s_in_block
<
block_S
&&
if
(
s_in_block
<
block_S
&&
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
int
s
=
s_in_block
+
s_block_begin
,
int
s
=
s_in_block
+
s_block_begin
,
...
@@ -489,7 +489,8 @@ void mutual_information_kernel(
...
@@ -489,7 +489,8 @@ void mutual_information_kernel(
of p_grad, we need context on the top and right instead of the bottom and
of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1.
left. So there are offsets of 1.
*/
*/
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
BLOCK_SIZE
>
__global__
__global__
void
mutual_information_backward_kernel
(
void
mutual_information_backward_kernel
(
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
px
,
// B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
px
,
// B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
...
@@ -751,7 +752,7 @@ void mutual_information_backward_kernel(
...
@@ -751,7 +752,7 @@ void mutual_information_backward_kernel(
// mutual_information.py for documentation of the behavior of this function.
// mutual_information.py for documentation of the behavior of this function.
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
torch
::
Tensor
mutual_information_cuda
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
py
,
std
::
optional
<
torch
::
Tensor
>
optional_
boundary
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
)
{
torch
::
Tensor
p
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
...
@@ -767,12 +768,16 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
...
@@ -767,12 +768,16 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
T
=
px
.
size
(
2
)
-
1
;
T
=
px
.
size
(
2
)
-
1
;
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cuda
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
torch
::
Tensor
ans
=
torch
::
empty
({
B
},
opts
);
torch
::
Tensor
ans
=
torch
::
empty
({
B
},
opts
);
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128).
// (however, num_threads may not be less than 128).
int
num_threads
=
128
,
const
int
num_threads
=
128
,
num_blocks
=
256
,
num_blocks
=
256
,
BLOCK_SIZE
=
32
;
BLOCK_SIZE
=
32
;
...
@@ -783,21 +788,17 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
...
@@ -783,21 +788,17 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
if
((
bool
)
optional_boundary
)
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cuda_stub"
,
([
&
]
{
TORCH_CHECK
(
optional_boundary
.
value
().
device
().
is_cuda
(),
for
(
int
iter
=
0
;
iter
<
num_iters
;
++
iter
)
{
"boundary information must be in CUDA tensor"
);
mutual_information_kernel
<
scalar_t
,
BLOCK_SIZE
><<<
num_blocks
,
num_threads
>>>
(
else
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
optional_boundary
=
torch
::
empty
({
0
,
0
},
long_opts
);
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
for
(
int
iter
=
0
;
iter
<
num_iters
;
++
iter
)
{
boundary
.
packed_accessor32
<
int64_t
,
2
>
(),
mutual_information_kernel
<
scalar_t
,
BLOCK_SIZE
><<<
num_blocks
,
num_threads
>>>
(
ans
.
packed_accessor32
<
scalar_t
,
1
>
(),
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
iter
);
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
}
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
}));
optional_boundary
.
value
().
packed_accessor32
<
int64_t
,
2
>
(),
ans
.
packed_accessor32
<
scalar_t
,
1
>
(),
iter
);
}
return
ans
;
return
ans
;
}
}
...
@@ -807,12 +808,13 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
...
@@ -807,12 +808,13 @@ torch::Tensor mutual_information_cuda(torch::Tensor px,
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked
// should be identical to the original ans_grad if the computation worked
// as it should.
// as it should.
torch
::
Tensor
mutual_information_backward_cuda
(
torch
::
Tensor
px
,
std
::
vector
<
torch
::
Tensor
>
torch
::
Tensor
py
,
mutual_information_backward_cuda
(
torch
::
Tensor
px
,
std
::
optional
<
torch
::
Tensor
>
optional_boundary
,
torch
::
Tensor
py
,
torch
::
Tensor
p
,
torch
::
Tensor
boundary
,
torch
::
Tensor
ans_grad
,
torch
::
Tensor
p
,
bool
overwrite_ans_grad
)
{
torch
::
Tensor
ans_grad
,
bool
overwrite_ans_grad
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
...
@@ -832,9 +834,13 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
...
@@ -832,9 +834,13 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
(
ans_grad
.
size
(
0
)
==
b
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cuda
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
TORCH_CHECK
(
ans_grad
.
size
(
0
)
==
B
);
bool
has_boundary
=
(
bo
ol
)
optional_boundary
;
bool
has_boundary
=
(
bo
undary
.
size
(
0
)
!=
0
)
;
torch
::
Tensor
p_grad
=
torch
::
empty
({
B
,
S
+
1
,
T
+
1
},
opts
),
torch
::
Tensor
p_grad
=
torch
::
empty
({
B
,
S
+
1
,
T
+
1
},
opts
),
px_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
,
T
+
1
},
opts
)
:
px_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
,
T
+
1
},
opts
)
:
...
@@ -855,25 +861,22 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
...
@@ -855,25 +861,22 @@ torch::Tensor mutual_information_backward_cuda(torch::Tensor px,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_t_blocks
=
T
/
BLOCK_SIZE
+
1
,
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
num_iters
=
num_s_blocks
+
num_t_blocks
-
1
;
if
(
has_boundary
)
TORCH_CHECK
(
optional_boundary
.
value
().
device
().
is_cuda
(),
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_backward_stub"
,
([
&
]
{
"boundary information must be in CUDA tensor"
);
for
(
int
iter
=
num_iters
-
1
;
iter
>=
0
;
--
iter
)
{
else
mutual_information_backward_kernel
<
scalar_t
,
BLOCK_SIZE
><<<
num_blocks
,
num_threads
>>>
(
optional_boundary
=
torch
::
empty
({
0
,
0
},
long_opts
);
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
for
(
int
iter
=
num_iters
-
1
;
iter
>=
0
;
--
iter
)
{
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
mutual_information_backward_kernel
<
scalar_t
,
BLOCK_SIZE
><<<
num_blocks
,
num_threads
>>>
(
ans_grad
.
packed_accessor32
<
scalar_t
,
1
>
(),
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
ans_grad
.
packed_accessor32
<
scalar_t
,
1
>
,
boundary
.
packed_accessor32
<
int64_t
,
2
>
(),
p_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
iter
,
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
overwrite_ans_grad
);
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
}
optional_boundary
.
value
().
packed_accessor32
<
int64_t
,
2
>
(),
}));
iter
,
overwrite_ans_grad
);
}
std
::
cout
<<
"p_grad = "
<<
p_grad
;
std
::
cout
<<
"p_grad = "
<<
p_grad
;
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment