OpenDAS / FAST-RNNT · Commits

Commit 77eed83f, authored Jul 28, 2021 by Daniel Povey
"Some progress, still drafting."
Parent: e95d7864

Showing 4 changed files with 678 additions and 364 deletions (+678, -364).
torch_mutual_information/mutual_information.py               +29  -21
torch_mutual_information/mutual_information_cpu.cpp          +173 -58
torch_mutual_information/mutual_information_cuda.cpp         +17  -8
torch_mutual_information/mutual_information_cuda_kernel.cu   +459 -277
torch_mutual_information/mutual_information.py
...
...
@@ -73,6 +73,14 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
                boundaries: torch.Tensor) -> torch.Tensor:
        (B, S, T) = px.shape
        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
        # the mutual information of the pair of subsequences of x and y that are of
        # length s and t respectively.  p[0][0] will be 0.0 and p[S][T] is
        # the mutual information of the entire pair of sequences, i.e. of lengths
        # S and T respectively.
        # q is a rearrangement of a tensor p which is of shape (B,S,T),
        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
        # representation is that each row of q depends only on the previous row,
...
...
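As a rough illustration (not part of this commit) of the reindexing described in the comment above, p[b, s, t] == q[b, s + t, t], here is a toy Python sketch; the tensor sizes are made up:

import torch

# Toy sizes, chosen only for illustration.
B, S, T = 2, 3, 4
p = torch.randn(B, S, T)
q = torch.full((B, S + T, T), float("-inf"))
for s in range(S):
    for t in range(T):
        q[:, s + t, t] = p[:, s, t]
# Row i of q collects all p[:, s, t] with s + t == i, so each row of q depends
# only on the previous row, which is the point the comment above is making.
assert torch.equal(q[:, 2, 1], p[:, 1, 1])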
@@ -85,15 +93,9 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
        if px.requires_grad or py.requires_grad:
            q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
        else:
            # We don't need to store q if we are not going to do backprop, but we
            # do pass in a temporary with one real row, expanded to have "fake" rows,
            # which happens to be convenient for the CPU implementation.
            q = torch.empty({1, 1, T}, device=px.device, dtype=px.dtype).expand(B, S + T, T)
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
-       ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
+       ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
        if px.requires_grad or py.requires_grad:
            ctx.save_for_backward(px, py, boundaries, q)
...
...
@@ -115,7 +117,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
make use of the formula computed by this function.
Args:
-      px: A torch.Tensor of some floating point type, with shape [B][S][T],
+      px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
          where B is the batch size, S is the length of the 'x' sequence
          (including representations of EOS symbols but not BOS symbols), and T is the
          length of the 'y' sequence (including representations of
...
...
@@ -139,13 +141,13 @@ def mutual_information_recursion(input, px, py, boundaries=None):
code assumes for optimization purposes that the T axis has
stride 1.
-      py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
+      py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
          representing
             py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
          This function does not treat x and y differently; the only difference
-         is that the implementation assumes for optimization purposes that y
-         is likely to be the shorter sequence, i.e. that "most of the time T < S",
-         and it will be faster if you respect this.
+         is that for optimization purposes we assume the last axis (the t axis)
+         has stride of 1; this is true if px and py are contiguous.
boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
...
...
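As a small illustration (not part of this diff), a boundaries tensor for B = 2, with made-up values, would look like this:

import torch

# Each row is [s_begin, t_begin, s_end, t_end] for one batch element.
boundaries = torch.tensor([[0, 0, 5, 4],
                           [1, 2, 6, 7]], dtype=torch.int64)  # shape [B, 4]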
@@ -155,18 +157,24 @@ def mutual_information_recursion(input, px, py, boundaries=None):
Returns:
       Returns a torch.Tensor of shape [B], containing the log of the mutual
       information between the b'th pair of sequences.  This is defined by
-      the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
+      the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
       representing a mutual information between sub-sequences of lengths s and t:
-           p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
-      where in the case where boundaries==None: the edge cases are handled
-      by treating p[b,-1,-1] as 0 and all other quantities with negative
-      indexes as -infinity; and ans[b] would equal p[S-1,T-1].  The extension to
-      cases where the boundaries are specified should be obvious.
+           p[b,0,0] = 0.0
+           p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                              p[b,s,t-1] + py[b,s,t-1])      (if s > 0 or t > 0)
+      where we handle edge cases by treating quantities with negative indexes
+      as -infinity.  The extension to cases where the boundaries are specified
+      should be obvious; it just works on shorter sequences with offsets into
+      px and py.
    """
-    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
+    assert px.ndim == 3
+    B, S, T1 = px.shape
+    T = T1 - 1
+    assert py.shape == (B, S + 1, T)
+    assert px.dtype == py.dtype
-    (B, S, T) = px.shape
     if boundaries is not None:
         assert boundaries.dtype == torch.LongTensor
...
...
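The recursion documented above can also be written as a slow pure-Python reference, which may help when reading the C++ below; this sketch is an illustration only (the function name and the use of torch.logaddexp are not from this commit):

import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: [B, S, T+1], py: [B, S+1, T]; returns a tensor of shape [B].
    B, S, T1 = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = torch.full((B,), float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]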
torch_mutual_information/mutual_information_cpu.cpp
...
...
@@ -3,6 +3,13 @@
inline double Exp(double x) { return exp(x); }
inline double Exp(float x) { return expf(x); }

// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
  double diff;
...
...
@@ -14,8 +21,7 @@ inline double LogAdd(double x, double y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
-  if (diff >= kMinLogDiffDouble) {
+  if (diff >= -1000) {
    double res;
    res = x + log1p(exp(diff));
    return res;
...
...
@@ -35,99 +41,208 @@ inline float LogAdd(float x, float y) {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
-  if (diff >= kMinLogDiffFloat) {
+  if (diff >= -200) {
    float res;
    res = x + log1pf(expf(diff));
    return res;
  }
  return x;  // return the larger one.
}
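For readers less used to this idiom, the LogAdd functions above implement the usual numerically stable log(exp(x) + exp(y)); a Python rendering (illustration only, and the cutoff constant is an assumption, roughly float32 epsilon) looks like:

import math

MIN_LOG_DIFF = math.log(1.19209290e-07)   # assumed cutoff value

def log_add(x: float, y: float) -> float:
    if x < y:
        x, y = y, x
    diff = y - x                 # diff <= 0; x is now the larger one
    if diff >= MIN_LOG_DIFF:
        return x + math.log1p(math.exp(diff))
    return x                     # the smaller term underflows; return the larger one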
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cpu(torch::Tensor px,
                                     torch::Tensor py,
                                     std::optional<torch::Tensor> optional_boundary,
-                                    torch::Tensor q) {
+                                    torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
-  TORCH_CHECK(py.dim() == 3, "params must be 3-dimensional.");
-  TORCH_CHECK(q.dim() == 3, "params must be 3-dimensional.");
+  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
+  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
              p.device().is_cpu(),
              "inputs must be CPU tensors");
  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
  const int B = px.size(0), S = px.size(1),
-           T = px.size(2);
-  TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
+           T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
  torch::Tensor ans = torch::empty({B}, opts);
-  auto long_opts = torch::TensorOptions().dtype(torch::kInt64);
+  auto long_opts = torch::TensorOptions().dtype(torch::kInt64)
+                       .device(px.device());
  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
-    optional_boundary = torch::empty({}, long_opts);
+    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().dtype == torch::kInt64);
  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
-    auto px_a = px.accessor<scalar_t, 3>(),
-         py_a = py.accessor<scalar_t, 3>();
-    for (int c = 0; c < C; c++) {
-      scalar_t sum_negative = 0.0, sum_positive = 0.0,
-          scale = exp(params_a[c][0]);
-      for (int i = 0; i < K; i++) {
-        scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-            neg_scaled_param = params_a[c][K - i] * scale;
-        y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-        sum_positive += pos_scaled_param;
-        sum_negative -= neg_scaled_param;
-        y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
    auto px_a = px.packed_accessor32<scalar_t, 3>(),
         py_a = py.packed_accessor32<scalar_t, 3>(),
         p_a = p.packed_accessor32<scalar_t, 3>();
    auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
    auto ans_a = ans.packed_accessor32<scalar_t, 1>();
    for (int b = 0; b < B; b++) {
      int s_begin, s_end, t_begin, t_end;
      if (has_boundary) {
        s_begin = boundary_a[b][0];
        t_begin = boundary_a[b][1];
        s_end = boundary_a[b][2];
        t_end = boundary_a[b][3];
      } else {
        s_begin = 0;
        s_end = S;
        t_begin = 0;
        t_end = T;
      }
      p_a[b][s_begin][t_begin] = 0.0;
      for (int s = s_begin + 1; s <= s_end; ++s)
        p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
      for (int t = t_begin + 1; t <= t_end; ++t)
        p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
      for (int s = s_begin + 1; s <= s_end; ++s) {
        scalar_t p_s_t1 = p_a[b][s][t_begin];
        for (int t = t_begin + 1; t <= t_end; ++t) {
          // The following statement is a small optimization of:
          //    p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
          //                          p_a[b][s][t - 1] + py_a[b][s][t - 1]);
          // .. which obtains p_a[b][s][t - 1] from a register.
          p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
                                         p_s_t1 + py_a[b][s][t - 1]);
        }
      }
      ans_a[b] = p_a[b][s_end][t_end];
    }
  }));
  return ans;
}
-    auto input_a = input.accessor<scalar_t, 3>(),
-         output_a = output.accessor<scalar_t, 3>();
-    for (int b = 0; b < B; b++) {
-      for (int c = 0; c < C; c++) {
-        scalar_t scale = exp(params_a[c][0]),
-            inv_scale = 1.0 / scale;
-        for (int t = 0; t < T; t++) {
-          // `x` is the scaled input x plus an offset so that -K maps to 0.
-          // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
-          // so in a sense -K and +K are not special, but we include those
-          // extra values as an easy way to handle the semi-infinite regions
-          // that are < -(K-1) and > (K-1)
-          scalar_t input = input_a[b][c][t],
-              x = input * inv_scale + K;
-          if (x < 0) x = 0;
-          else if (x >= N) x = N - 1;
-          // C++ rounds toward zero.
-          int n = (int)x;
-          // OK, at this point, 0 <= min < 2*K.
-          output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> mutual_information_backward_cpu(
    torch::Tensor px, torch::Tensor py,
    std::optional<torch::Tensor> optional_boundary,
    torch::Tensor p, torch::Tensor ans_grad) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
  TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
              p.device().is_cpu() && ans_grad.device().is_cpu(),
              "inputs must be CPU tensors");
  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
  const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
  auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());
  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({0, 0}, long_opts);
  TORCH_CHECK(optional_boundary.value().device().is_cpu() &&
              optional_boundary.value().dtype == torch::kInt64);
  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
    auto px_a = px.packed_accessor32<scalar_t, 3>(),
         py_a = py.packed_accessor32<scalar_t, 3>(),
         p_a = p.packed_accessor32<scalar_t, 3>(),
         p_grad_a = p.packed_accessor32<scalar_t, 3>();
    auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
    auto boundary_a = optional_boundary.value().packed_accessor32<int64_t, 2>();
    for (int b = 0; b < B; b++) {
      int s_begin, s_end, t_begin, t_end;
      if (has_boundary) {
        s_begin = boundary_a[b][0];
        t_begin = boundary_a[b][1];
        s_end = boundary_a[b][2];
        t_end = boundary_a[b][3];
      } else {
        s_begin = 0;
        s_end = S;
        t_begin = 0;
        t_end = T;
      }
      // Backprop for: ans_a[b] = p_a[b][s_end][t_end];
      p_grad_a[b][s_end][t_end] = ans_grad_a[b];
      for (int s = s_end; s > s_begin; --s) {
        for (int t = t_end; t > t_begin; --t) {
          // The statement we are backpropagating here is:
          //    p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
          //                          p_a[b][s][t - 1] + py_a[b][s][t - 1]);
          // .. which obtains p_a[b][s][t - 1] from a register.
          scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
              total = p_a[b][s][t],
              term1_deriv = exp(term1 - total),
              term2_deriv = 1.0 - term1_deriv,
              grad = p_grad_a[b][s][t],
              term1_grad = term1_deriv * grad,
              term2_grad = term2_deriv * grad;
          // We can assign to px_grad_a here rather than add, because we
          // know it's currently zero.
          TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
          px_grad_a[b][s - 1][t] = term1_grad;
          TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);
          // likewise..
          p_grad_a[b][s - 1][t] = term1_grad;
          py_grad_a[b][s][t - 1] += term2_grad;
          p_grad_a[b][s][t - 1] += term2_grad;
        }
      }
-  }}));
-  return output;
      for (int t = t_end; t >= t_begin; --t) {
        // Backprop for:
        //   p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
        scalar_t this_p_grad = p_grad_a[b][s_begin][t];
        p_grad_a[b][s_begin][t - 1] += this_p_grad;
        py_grad_a[b][s_begin][t - 1] += this_p_grad;
      }
      for (int s = s_end; s >= s_begin; --s) {
        // Backprop for:
        //   p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
        scalar_t this_p_grad = p_grad_a[b][s][s_begin];
        p_a[b][s - 1][t_begin] += this_p_grad;
        px_a[b][s - 1][t_begin] += this_p_grad;
      }
      // There is no backprop for:
      //   p_a[b][s_begin][t_begin] = 0.0;
      // .. but we can use this for a check, that the grad at the beginning
      // of the sequence is equal to the grad at the end of the sequence.
      if (ans_grad_a[b] != 0.0) {
        float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
        if (grad_ratio - 1.0 > 0.01) {
          printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                 (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
        }
      }
    }
  }));
  return ans;
}
-// backward of mutual_information. Returns (input_grad, params_grad)
-std::vector<torch::Tensor> mutual_information_backward_cpu(
-    torch::Tensor input, torch::Tensor params, torch::Tensor output_grad) {
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
...
...
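The term1_deriv / term2_deriv quantities in the backward loop above are the standard derivatives of log-add: if p = logaddexp(a, b) then dp/da = exp(a - p) and dp/db = 1 - exp(a - p). A quick check (illustration only, not part of this commit):

import torch

a = torch.tensor(0.3, requires_grad=True)
b = torch.tensor(-1.2, requires_grad=True)
p = torch.logaddexp(a, b)
p.backward()
term1_deriv = torch.exp(a - p)        # should equal a.grad
term2_deriv = 1.0 - term1_deriv       # should equal b.grad
assert torch.allclose(a.grad, term1_deriv)
assert torch.allclose(b.grad, term2_deriv)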
torch_mutual_information/mutual_information_cuda.cpp
...
...
@@ -3,18 +3,27 @@
// forward of mutual_information. """... """ comment of `mutual_information`
// in mutual_information.py documents the behavior of this function.
-torch::Tensor mutual_information_cuda(torch::Tensor input,
-                                      torch::Tensor params);
+// It is the core recursion in the sequence-to-sequence mutual information
+// computation.
+// returns 'ans', of dimension B (batch size).
+torch::Tensor mutual_information_cuda(torch::Tensor px,   // [B][S][T+1]
+                                      torch::Tensor py,   // [B][S+1][T]
+                                      std::optional<torch::Tensor> boundary_info,  // [B][4], int64_t.
+                                      torch::Tensor p);   // [B][S+1][T+1]; an output
-// backward of mutual_information; returns (grad_input, grad_params).
-std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
-                                                            torch::Tensor params,
-                                                            torch::Tensor grad_output);
+// backward of mutual_information; returns (grad_px, grad_py)
+std::vector<torch::Tensor> mutual_information_backward_cuda(
+    torch::Tensor px, torch::Tensor py,
+    std::optional<torch::Tensor> boundary_info,
+    torch::Tensor p, torch::Tensor ans_grad);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("mutual_information_cuda", &mutual_information_cuda,
-        "Learned nonlinearity forward function (CUDA)");
-  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
-        "Learned nonlinearity backward function (CUDA)");
+  m.def("mutual_information_cuda", &mutual_information_cuda,
+        "Mutual information forward function (CUDA)");
+  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
+        "Mutual information backward function (CUDA)");
}
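Assuming the extension built from this file is importable (the module name torch_mutual_information_cuda below is a guess, as are the tensor sizes), the bindings above would be called from Python roughly like this:

import torch
import torch_mutual_information_cuda as mi_cuda   # assumed extension name

B, S, T = 2, 5, 4
px = torch.randn(B, S, T + 1, device="cuda")
py = torch.randn(B, S + 1, T, device="cuda")
p = torch.empty(B, S + 1, T + 1, device="cuda")          # filled by the kernel
ans = mi_cuda.mutual_information_cuda(px, py, None, p)   # boundary_info omitted
print(ans.shape)   # one value per batch element, shape [B]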
torch_mutual_information/mutual_information_cuda_kernel.cu
(diff collapsed, not shown)