OpenDAS / FAST-RNNT · Commits

Commit 77eed83f
Authored Jul 28, 2021 by Daniel Povey
"Some progress, still drafting."
Parent: e95d7864
Showing 4 changed files with 678 additions and 364 deletions.
torch_mutual_information/mutual_information.py            (+29 −21)
torch_mutual_information/mutual_information_cpu.cpp       (+173 −58)
torch_mutual_information/mutual_information_cuda.cpp      (+17 −8)
torch_mutual_information/mutual_information_cuda_kernel.cu  (+459 −277)
torch_mutual_information/mutual_information.py
@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):

     def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
         (B, S, T) = px.shape
+        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
+        # the mutual information of the pair of subsequences of x and y that are of
+        # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
+        # the mutual information of the entire pair of sequences, i.e. of lengths
+        # S and T respectively.
         # q is a rearrangement of a tensor p which is of shape (B,S,T),
         # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
         # representation is that each row of q depends only on the previous row,
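Note: the comment added here describes the rearrangement p[b,s,t] == q[b,s+t,t], i.e. q stores p along anti-diagonals, so that each row of q depends only on the previous row. A small sketch of that index mapping, for illustration only (not part of this commit; the shapes are chosen just for the demonstration):

import torch

B, S, T = 1, 3, 4
p = torch.arange(B * (S + 1) * (T + 1), dtype=torch.float32).reshape(B, S + 1, T + 1)

# q holds the anti-diagonals of p: q[b, s + t, t] == p[b, s, t].
q = torch.full((B, S + T + 1, T + 1), float("-inf"))
for s in range(S + 1):
    for t in range(T + 1):
        q[:, s + t, t] = p[:, s, t]

# p[b, s, t] is computed from p[b, s-1, t] and p[b, s, t-1]; both lie on
# anti-diagonal s + t - 1, i.e. on the previous row of q, so the rows of q
# can be filled strictly one after another (and, on a GPU, each row in parallel).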
@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):

         #    q[b, s-s_begin + t-t_begin, t-t_begin];
         # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
         if px.requires_grad or py.requires_grad:
-            q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
-        else:
-            # We don't need to store q if we are not going to do backprop, but we
-            # do pass in a temporary with one real row, expanded to have "fake" rows,
-            # which happens to be convenient for the CPU implementation.
-            q = torch.empty({1, 1, T}, device=px.device, dtype=px.dtype).expand(B, S + T, T)
-        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
+            p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
         if px.requires_grad or py.requires_grad:
             ctx.save_for_backward(px, py, boundaries, q)
@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):

       make use of the formula computed by this function.
     Args:
-       px: A torch.Tensor of some floating point type, with shape [B][S][T],
+       px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
           where B is the batch size, S is the length of the 'x' sequence
           (including representations of EOS symbols but not BOS symbols), and T is the
           length of the 'y' sequence (including representations of
@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):

           code assumes for optimization purposes that the T axis has
           stride 1.
-       py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
+       py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
           representing
              py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
           This function does not treat x and y differently; the only difference
-          is that the implementation assumes for optimization purposes that y
-          is likely to be the shorter sequence, i.e. that "most of the time T < S",
-          and it will be faster if you respect this.
+          is that for optimization purposes we assume the last axis (the t axis)
+          has stride of 1; this is true if px and py are contiguous.
       boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
           [s_begin, t_begin, s_end, t_end]. If not supplied, the values
           [0, 0, S, T] will be assumed. These are the beginning and
@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):

     Returns:
         Returns a torch.Tensor of shape [B], containing the log of the mutual
         information between the b'th pair of sequences. This is defined by
-        the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
+        the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
         representing a mutual information between sub-sequences of lengths s and t:
-             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
-        where in the case where boundaries==None: the edge cases are handled
-        by treating p[b,-1,-1] as 0 and all other quantities with negative
-        indexes as -infinity; and ans[b] would equal p[S-1,T-1]. The extension to
-        cases where the boundaries are specified should be obvious.
+             p[b,0,0] = 0.0
+             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                p[b,s,t-1] + py[b,s,t-1])   (if s > 0 or t > 0)
+        where we handle edge cases by treating quantities with negative indexes
+        as -infinity. The extension to cases where the boundaries are specified
+        should be obvious; it just works on shorter sequences with offsets into
+        px and py.
     """
-    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
+    assert px.ndim == 3
+    B, S, T1 = px.shape
+    T = T1 - 1
+    assert py.shape == (B, S + 1, T)
+    assert px.dtype == py.dtype
     (B, S, T) = px.shape
     if boundaries is not None:
         assert boundaries.dtype == torch.LongTensor
torch_mutual_information/mutual_information_cpu.cpp
@@ -3,6 +3,13 @@

+inline double Exp(double x) { return exp(x); }
+inline double Exp(float x) { return expf(x); }

 // returns log(exp(x) + exp(y)).
 inline double LogAdd(double x, double y) {
   double diff;
@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {

     diff = y - x;
   }
   // diff is negative. x is now the larger one.
-  if (diff >= -1000) {
+  if (diff >= kMinLogDiffDouble) {
     double res;
     res = x + log1p(exp(diff));
     return res;
@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {

     diff = y - x;
   }
   // diff is negative. x is now the larger one.
-  if (diff >= -200) {
+  if (diff >= kMinLogDiffFloat) {
     float res;
     res = x + log1pf(expf(diff));
     return res;
   }
   return x;  // return the larger one.
 }

 // forward of mutual_information.  See """... """ comment of `mutual_information` in
 // mutual_information.py for documentation of the behavior of this function.
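Note: the LogAdd helpers above compute log(exp(x) + exp(y)) stably: the larger argument is kept, and log1p(exp(diff)) is added only when diff (which is <= 0) is above a cutoff (kMinLogDiffDouble / kMinLogDiffFloat in the new code); below the cutoff exp(diff) underflows and the larger value is returned unchanged. A rough Python equivalent, as an illustrative sketch only (the cutoff value here is an assumption, not taken from the commit):

import math

# Illustrative cutoff; the actual kMinLogDiffFloat in the code may differ.
MIN_LOG_DIFF = math.log(1.19209290e-7)

def log_add(x: float, y: float) -> float:
    if y > x:
        x, y = y, x           # make x the larger one, so diff <= 0
    diff = y - x
    if diff >= MIN_LOG_DIFF:
        return x + math.log1p(math.exp(diff))
    return x                  # exp(diff) underflows; the larger value dominates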
 torch::Tensor mutual_information_cpu(torch::Tensor px,
                                      torch::Tensor py,
                                      std::optional<torch::Tensor> optional_boundary,
-                                     torch::Tensor q) {
+                                     torch::Tensor p) {
   TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
-  TORCH_CHECK(py.dim() == 3, "params must be 3-dimensional.");
-  TORCH_CHECK(q.dim() == 3, "params must be 3-dimensional.");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
+              "inputs must be CPU tensors");

   auto scalar_t = px.scalar_type();
   auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

   const int B = px.size(0),
             S = px.size(1),
-            T = px.size(2);
-  TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
+            T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

   torch::Tensor ans = torch::empty({B}, opts);

-  auto long_opts = torch::TensorOptiona().dtype(torch::kInt64);
+  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

   bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
-    optional_boundary = torch::empty({}, long_opts);
+    optional_boundary = torch::empty({0, 0}, long_opts);
+  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
+              optional_boundary.value().dtype == torch::kInt64);

   AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
-      auto px_a = px.accessor<scalar_t, 3>(),
-           py_a = py.accessor<scalar_t, 3>();
-      for (int c = 0; c < C; c++) {
-        scalar_t sum_negative = 0.0,
-            sum_positive = 0.0,
-            scale = exp(params_a[c][0]);
-        for (int i = 0; i < K; i++) {
-          scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-              neg_scaled_param = params_a[c][K - i] * scale;
-          y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-          sum_positive += pos_scaled_param;
-          sum_negative -= neg_scaled_param;
-          y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-        }
-      }
+      auto px_a = px.packed_accessor32<scalar_t, 3>(),
+           py_a = py.packed_accessor32<scalar_t, 3>(),
+           p_a = p.packed_accessor32<scalar_t, 3>();
+      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
+      auto ans_a = ans.packed_accessor32<scalar_t, 1>();
+      for (int b = 0; b < B; b++) {
+        int s_begin, s_end, t_begin, t_end;
+        if (has_boundary) {
+          s_begin = boundary_a[b][0];
+          t_begin = boundary_a[b][1];
+          s_end = boundary_a[b][2];
+          t_end = boundary_a[b][3];
+        } else {
+          s_begin = 0;
+          s_end = S;
+          t_begin = 0;
+          t_end = T;
+        }
+        p_a[b][s_begin][t_begin] = 0.0;
+        for (int s = s_begin + 1; s <= s_end; ++s)
+          p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+        for (int t = t_begin + 1; t <= t_end; ++t)
+          p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+        for (int s = s_begin + 1; s <= s_end; ++s) {
+          scalar_t p_s_t1 = p_a[b][s][t_begin];
+          for (int t = t_begin + 1; t <= t_end; ++t) {
+            // The following statement is a small optimization of:
+            //    p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+            //                          p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+            // .. which obtains p_a[b][s][t - 1] from a register.
+            p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                                           p_s_t1 + py_a[b][s][t - 1]);
+          }
+        }
+        ans_a[b] = p_a[b][s_end][t_end];
+      }
+    }));
+  return ans;
+}
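Note: the new forward loop above first fills the first row and first column of p inside the rectangle given by [s_begin, t_begin, s_end, t_end], then sweeps the interior with LogAdd, carrying p_a[b][s][t - 1] in a register (p_s_t1). A rough Python transcription of the per-batch loop, as an illustrative sketch only (names mirror the C++ above; this is not part of the commit):

import torch

def forward_one_sequence(px, py, p, b, s_begin, t_begin, s_end, t_end):
    # px: [B, S, T+1], py: [B, S+1, T], p: [B, S+1, T+1] (written in place).
    p[b, s_begin, t_begin] = 0.0
    for s in range(s_begin + 1, s_end + 1):
        p[b, s, t_begin] = p[b, s - 1, t_begin] + px[b, s - 1, t_begin]
    for t in range(t_begin + 1, t_end + 1):
        p[b, s_begin, t] = p[b, s_begin, t - 1] + py[b, s_begin, t - 1]
    for s in range(s_begin + 1, s_end + 1):
        p_s_t1 = p[b, s, t_begin]  # plays the role of the register in the C++ loop
        for t in range(t_begin + 1, t_end + 1):
            p_s_t1 = torch.logaddexp(p[b, s - 1, t] + px[b, s - 1, t],
                                     p_s_t1 + py[b, s, t - 1])
            p[b, s, t] = p_s_t1
    return p[b, s_end, t_end]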
-      auto input_a = input.accessor<scalar_t, 3>(),
-           output_a = output.accessor<scalar_t, 3>();
-      for (int b = 0; b < B; b++) {
-        for (int c = 0; c < C; c++) {
-          scalar_t scale = exp(params_a[c][0]),
-              inv_scale = 1.0 / scale;
-          for (int t = 0; t < T; t++) {
-            // `x` is the scaled input x plus an offset so that -K maps to 0.
-            // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
-            // so in a sense -K and +K are not special, but we include those
-            // extra values as an easy way to handle the semi-infinite regions
-            // that are < -(K-1) and > (K-1)
-            scalar_t input = input_a[b][c][t],
-                x = input * inv_scale + K;
-            if (x < 0) x = 0;
-            else if (x >= N) x = N - 1;
-            // C++ rounds toward zero.
-            int n = (int) x;
-            // OK, at this point, 0 <= min < 2*K.
-            output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-          }
-        }
-      }
-    }));
-  return output;
-}

+// backward of mutual_information.  Returns (px_grad, py_grad).
+// p corresponds to what we computed in the forward pass.
+std::vector<torch::Tensor> mutual_information_backward_cpu(
+    torch::Tensor px,
+    torch::Tensor py,
+    std::optional<torch::Tensor> optional_boundary,
+    torch::Tensor p,
+    torch::Tensor ans_grad) {
+  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
+  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 3-dimensional.");
+  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
+              && ans_grad.device() == cpu(),
+              "inputs must be CPU tensors");
+
+  auto scalar_t = px.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
+
+  const int B = px.size(0),
+            S = px.size(1),
+            T = px.size(2) - 1;
+  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
+  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
+
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
+
+  bool has_boundary = (bool)optional_boundary;
+  if (!has_boundary)
+    optional_boundary = torch::empty({0, 0}, long_opts);
+  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
+              optional_boundary.value().dtype == torch::kInt64);
+
+  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
+      auto px_a = px.packed_accessor32<scalar_t, 3>(),
+           py_a = py.packed_accessor32<scalar_t, 3>(),
+           p_a = p.packed_accessor32<scalar_t, 3>(),
+           p_grad_a = p.packed_accessor32<scalar_t, 3>();
+      auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
+      auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
+      for (int b = 0; b < B; b++) {
+        int s_begin, s_end, t_begin, t_end;
+        if (has_boundary) {
+          s_begin = boundary_a[b][0];
+          t_begin = boundary_a[b][1];
+          s_end = boundary_a[b][2];
+          t_end = boundary_a[b][3];
+        } else {
+          s_begin = 0;
+          s_end = S;
+          t_begin = 0;
+          t_end = T;
+        }
+        // Backprop for: ans_a[b] = p_a[b][s_end][t_end];
+        p_grad_a[b][s_end][t_end] = ans_grad_a[b];
+        for (int s = s_end; s > s_begin; --s) {
+          for (int t = t_end; t > t_begin; --t) {
+            // The statement we are backpropagating here is:
+            //    p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
+            //                          p_a[b][s][t - 1] + py_a[b][s][t - 1]);
+            // .. which obtains p_a[b][s][t - 1] from a register.
+            scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
+                total = p_a[b][s][t],
+                term1_deriv = exp(term1 - total),
+                term2_deriv = 1.0 - term1_deriv,
+                grad = p_grad_a[b][s][t],
+                term1_grad = term1_deriv * grad,
+                term2_grad = term2_deriv * grad;
+            // We can assign to px_grad_a here rather than add, because we
+            // know it's currently zero.
+            TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
+            px_grad_a[b][s - 1][t] = term1_grad;
+            TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);
+            // likewise..
+            p_grad_a[b][s - 1][t] = term1_grad;
+            py_grad_a[b][s][t - 1] += term2_grad;
+            p_grad_a[b][s][t - 1] += term2_grad;
+          }
+        }
+        for (int t = t_end; t >= t_begin; --t) {
+          // Backprop for:
+          //    p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
+          scalar_t this_p_grad = p_grad_a[b][s_begin][t];
+          p_grad_a[b][s_begin][t - 1] += this_p_grad;
+          py_grad_a[b][s_begin][t - 1] += this_p_grad;
+        }
+        for (int s = s_end; s >= s_begin; --s) {
+          // Backprop for:
+          //    p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
+          scalar_t this_p_grad = p_grad_a[b][s][s_begin];
+          p_a[b][s - 1][t_begin] += this_p_grad;
+          px_a[b][s - 1][t_begin] += this_p_grad;
+        }
+        // There is no backprop for:
+        //    p_a[b][s_begin][t_begin] = 0.0;
+        // .. but we can use this for a check, that the grad at the beginning
+        // of the sequence is equal to the grad at the end of the sequence.
+        if (ans_grad_a[b] != 0.0) {
+          float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
+          if (grad_ratio - 1.0 > 0.01) {
+            printf("Warning: mutual_information backprop: expected these "
+                   "numbers to be the same: %f vs. %f\n",
+                   (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
+          }
+        }
+      }
+    }));
+  return ans;
+}

-// backward of mutual_information. Returns (input_grad, params_grad)
-std::vector<torch::Tensor> mutual_information_backward_cpu(torch::Tensor input,
-                                                           torch::Tensor params,
-                                                           torch::Tensor output_grad) {
 TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
 TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
 TORCH_CHECK(params.size(1) >= 3 &&
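Note: the statement being backpropagated in the new backward loop is total = LogAdd(term1, term2); its gradient splits between the two terms in proportion to exp(term1 - total) and 1 - exp(term1 - total), which is exactly what term1_deriv and term2_deriv compute. A standalone sketch of that derivative, for illustration only (not part of the commit):

import math

def log_add_with_grad(term1: float, term2: float):
    # total = log(exp(term1) + exp(term2)), computed stably.
    total = max(term1, term2) + math.log1p(math.exp(-abs(term1 - term2)))
    d_term1 = math.exp(term1 - total)   # d(total) / d(term1)
    d_term2 = 1.0 - d_term1             # d(total) / d(term2); the two sum to 1
    return total, d_term1, d_term2

def backprop_log_add(term1: float, term2: float, grad_total: float):
    # Mirrors term1_grad / term2_grad in the C++ loop above.
    _, d1, d2 = log_add_with_grad(term1, term2)
    return d1 * grad_total, d2 * grad_total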
torch_mutual_information/mutual_information_cuda.cpp
@@ -3,18 +3,27 @@

 // forward of mutual_information.  """... """ comment of `mutual_information`
 // in mutual_information.py documents the behavior of this function.
-torch::Tensor mutual_information_cuda(torch::Tensor input,
-                                      torch::Tensor params);
+// It is the core recursion in the sequence-to-sequence mutual information
+// computation.
+// returns 'ans', of dimension B (batch size).
+torch::Tensor mutual_information_cuda(torch::Tensor px,   // [B][S][T+1]
+                                      torch::Tensor py,   // [B][S+1][T]
+                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
+                                      torch::Tensor p);   // [B][S+1][T+1]; an output

-// backward of mutual_information; returns (grad_input, grad_params).
-std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
-                                                            torch::Tensor params,
-                                                            torch::Tensor grad_output);
+// backward of mutual_information; returns (grad_px, grad_py)
+std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor px,
+                                                            torch::Tensor py,
+                                                            std::optional<torch::Tensor> boundary_info,
+                                                            torch::Tensor p,
+                                                            torch::Tensor ans_grad);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("mutual_information_cuda", &mutual_information_cuda,
-        "Learned nonlinearity forward function (CUDA)");
+  m.def("mutual_information_cuda", &mutual_information_cuda,
+        "Mutual information forward function (CUDA)");
-  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
-        "Learned nonlinearity backward function (CUDA)");
+  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
+        "Mutual information backward function (CUDA)");
 }
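Note: with the updated bindings, a call from Python would pass px, py, an optional boundary tensor, and a pre-allocated p. The sketch below is purely illustrative; the extension module name and the exact calling convention are assumptions, not taken from this commit:

import torch
import torch_mutual_information_cuda as ext   # hypothetical module name

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1, device="cuda")
py = torch.randn(B, S + 1, T, device="cuda")
boundary = None          # or an int64 tensor of shape [B, 4]: [s_begin, t_begin, s_end, t_end]
p = torch.empty(B, S + 1, T + 1, device="cuda")

ans = ext.mutual_information_cuda(px, py, boundary, p)      # shape [B]
px_grad, py_grad = ext.mutual_information_backward_cuda(px, py, boundary, p,
                                                        torch.ones_like(ans))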
torch_mutual_information/mutual_information_cuda_kernel.cu

(This diff is collapsed in the page view and is not shown here.)