OpenDAS / FAST-RNNT / Commits

Commit 3c1ec347, authored Jul 29, 2021 by Daniel Povey
Parent: 8ed6deff

    Get it to a stage where it looks like it might compile
Changes: 3 changed files, with 277 additions and 331 deletions.

  torch_mutual_information/mutual_information.py                +33   -27
  torch_mutual_information/mutual_information_cpu.cpp           +16  -113
  torch_mutual_information/mutual_information_cuda_kernel.cu   +228  -191
torch_mutual_information/mutual_information.py (view file @ 3c1ec347)

@@ -44,66 +44,72 @@ except ImportError:

 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                           boundaries: torch.Tensor,
-                                           q: torch.Tensor) -> torch.Tensor:
+                                           boundaries: torch.Tensor,
+                                           p: torch.Tensor) -> torch.Tensor:
     if input.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, q)
+            px, py, boundaries, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, q)
+            px, py, boundaries, p)


 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
                                             boundaries: torch.Tensor,
-                                            q: torch.Tensor
-                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                                            p: torch.Tensor,
+                                            ans_grad: torch.Tensor
+                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
-        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, q))
+        overwrite_ans_grad = True
+        if overwrite_ans_grad:
+            ans_grad_copy = ans_grad.clone()
+        ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
+            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+        if overwrite_ans_grad:
+            if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
+                print(f"Warning: possible excessive roundoff in mutual information backward "
+                      f"recursion: {ans_grad} vs. {ans_grad_copy}")
+        return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, q))
+            px, py, boundaries, p, ans_grad))


 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
-                boundaries: torch.Tensor) -> torch.Tensor:
-        (B, S, T) = px.shape
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
+                boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+        (B, S, T1) = px.shape
+        T = T1 - 1
         assert py.shape == (B, S + 1, T)
         if boundaries is not None:
             assert boundaries.shape == (B, 4)
-        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is related to
+        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is
         # the mutual information of the pair of subsequences of x and y that
         # are of length s and t respectively.  p[0][0] will be 0.0 and p[S][T]
         # is the mutual information of the entire pair of sequences, i.e. of
         # lengths S and T respectively.
-        # q is a rearrangement of the tensor p, which is of shape (B,S,T),
-        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
-        # representation is that each row of q depends only on the previous row,
-        # so we can access the rows sequentially and this leads to
-        # better memory access patterns.  We are assuming that most likely
-        # T < S, which means that q should not require much more memory than p.
-        #
-        # Actually we access q beginning from 0 indexes even if `boundaries`
-        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
-        # q[b, s-s_begin + t-t_begin, t-t_begin];
-        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
+        # It is computed as follows (in C++ and CUDA):
+        #   p[b,0,0] = 0.0
+        #   p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+        #                      p[b,s,t-1] + py[b,s,t-1])   if s > 0 or t > 0,
+        # treating values with any -1 index as -infinity.
+        # .. if `boundary` is set, we start from p[b,s_begin,t_begin] = 0.0.
+        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)

         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, q)
+            ctx.save_for_backward(px, py, boundaries, p)

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, q) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(
-            px, py, boundaries, q)
+        (px, py, boundaries, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(
+            px, py, boundaries, p, ans_grad)
         return (px_grad, py_grad, None)
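The comment block above fully specifies the forward computation. As a point of
reference, here is a minimal pure-PyTorch sketch of that recursion (my
illustration, not code from this repo; it ignores `boundaries` and loops
naively instead of working row-by-row):

    import torch

    def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
        # px: (B, S, T+1), py: (B, S+1, T).  Returns p[:, S, T], of shape (B,).
        (B, S, T1) = px.shape
        T = T1 - 1
        assert py.shape == (B, S + 1, T)
        neg_inf = float('-inf')
        p = torch.full((B, S + 1, T + 1), neg_inf, device=px.device, dtype=px.dtype)
        p[:, 0, 0] = 0.0  # base case of the recursion
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                # term1: arrive at (s, t) by consuming a symbol of x (the px term).
                term1 = (p[:, s - 1, t] + px[:, s - 1, t]) if s > 0 else \
                    torch.full((B,), neg_inf, device=px.device, dtype=px.dtype)
                # term2: arrive at (s, t) by consuming a symbol of y (the py term).
                term2 = (p[:, s, t - 1] + py[:, s, t - 1]) if t > 0 else \
                    torch.full((B,), neg_inf, device=px.device, dtype=px.dtype)
                p[:, s, t] = torch.logaddexp(term1, term2)
        return p[:, S, T]

The native kernels compute the same table, but traverse it in the rearranged
row layout described in the comments above so that each row depends only on
the previous one, which gives better memory access patterns.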
torch_mutual_information/mutual_information_cpu.cpp (view file @ 3c1ec347)

@@ -151,11 +151,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
   TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
   TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

-  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts);
+  bool has_boundary = (bool)optional_boundary;
+  torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
+      px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+                 torch::empty({B, S, T + 1}, opts)),
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
+                 torch::empty({B, S + 1, T}, opts));

   auto long_opts = torch::TensorOptions().dtype(torch::kInt64).device(px.device());

-  bool has_boundary = (bool)optional_boundary;
   if (!has_boundary)
     optional_boundary = torch::empty({0, 0}, long_opts);

@@ -166,7 +171,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
   auto px_a = px.packed_accessor32<scalar_t, 3>(),
       py_a = py.packed_accessor32<scalar_t, 3>(),
       p_a = p.packed_accessor32<scalar_t, 3>(),
-      p_grad_a = p.packed_accessor32<scalar_t, 3>();
+      p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
+      px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
+      py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
+  auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();

@@ -196,19 +203,17 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       // .. which obtains p_a[b][s][t - 1] from a register.
       scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
           // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1],  <-- not
           // actually needed..
           total = p_a[b][s][t],
           term1_deriv = exp(term1 - total),
           term2_deriv = 1.0 - term1_deriv,
           grad = p_grad_a[b][s][t],
           term1_grad = term1_deriv * grad,
           term2_grad = term2_deriv * grad;
       // We can assign to px_grad_a here rather than add, because we
       // know it's currently zero.
-      TORCH_CHECK(px_grad_a[b][s - 1][t] == 0);
       px_grad_a[b][s - 1][t] = term1_grad;
-      TORCH_CHECK(p_grad_a[b][s - 1][t] == 0.0);
-      // likewise..
-      p_grad_a[b][s - 1][t] = term1_grad
-      py_grad_a[b][s][t - 1] += term2_grad;
+      p_grad_a[b][s - 1][t] = term1_grad;
+      py_grad_a[b][s][t - 1] = term2_grad;
+      p_grad_a[b][s][t - 1] += term2_grad;
     }
   }
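The term1_deriv / term2_deriv lines in the hunk above use the standard
derivative of log-add: if total = log_add(term1, term2), then
d(total)/d(term1) = exp(term1 - total) and d(total)/d(term2) =
1 - exp(term1 - total). A quick PyTorch check of that identity (illustrative,
not part of the repo):

    import torch

    term1 = torch.tensor(0.3, requires_grad=True)
    term2 = torch.tensor(-1.2, requires_grad=True)
    total = torch.logaddexp(term1, term2)
    total.backward()

    # Same quantity the C++ loop computes as term1_deriv = exp(term1 - total).
    term1_deriv = (term1 - total).exp().detach()
    assert torch.allclose(term1.grad, term1_deriv)
    assert torch.allclose(term2.grad, 1.0 - term1_deriv)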
@@ -239,111 +244,9 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
         }
       }
     }));
-  return ans;
-}
-
-  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
-  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
-  TORCH_CHECK(params.size(1) >= 3 &&
-              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
-              "params.size(1) has invalid value, must be a power of 2 plus 1.");
-  TORCH_CHECK(params.size(0) == input.size(1),
-              "params vs input channels mismatch");
-  TORCH_CHECK(input.sizes() == output_grad.sizes(),
-              "Output-grad vs. input sizes mismatch.");
-  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
-  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
-  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
-
-  const int B = input.size(0),
-      C = input.size(1),
-      T = input.size(2),
-      N = params.size(1) - 1,
-      K = N / 2;
-
-  auto scalar_t = input.scalar_type();
-  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-
-  torch::Tensor y_vals = torch::empty({C, N}, opts),
-      y_vals_grad = torch::zeros({C, N}, opts),
-      params_grad = torch::zeros({C, N + 1}, opts),
-      input_grad = torch::zeros({B, C, T}, opts);
-
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
-        auto params_a = params.accessor<scalar_t, 2>(),
-            params_grad_a = params_grad.accessor<scalar_t, 2>(),
-            y_vals_a = y_vals.accessor<scalar_t, 2>(),
-            y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
-        for (int c = 0; c < C; c++) {
-          scalar_t sum_negative = 0.0, sum_positive = 0.0,
-              scale = exp(params_a[c][0]);
-          for (int i = 0; i < K; i++) {
-            scalar_t pos_scaled_param = params_a[c][1 + K + i] * scale,
-                neg_scaled_param = params_a[c][K - i] * scale;
-            y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-            sum_positive += pos_scaled_param;
-            sum_negative -= neg_scaled_param;
-            y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1);
-          }
-        }
-        auto input_a = input.accessor<scalar_t, 3>(),
-            output_grad_a = output_grad.accessor<scalar_t, 3>(),
-            input_grad_a = input_grad.accessor<scalar_t, 3>();
-        for (int b = 0; b < B; b++) {
-          for (int c = 0; c < C; c++) {
-            scalar_t inv_scale = exp(-params_a[c][0]);
-            for (int t = 0; t < T; t++) {
-              scalar_t input = input_a[b][c][t],
-                  x = input * inv_scale + K,
-                  output_grad = output_grad_a[b][c][t];
-              if (x < 0) x = 0;
-              else if (x >= N) x = N - 1;
-              // C++ rounds toward zero.
-              int n = (int) x;
-              // OK, at this point, 0 <= n < 2*K.
-              // backprop for:
-              // output_a[b][c][t] = input * params_a[c][n + 1] + y_vals_a[c][n];
-              params_grad_a[c][n + 1] += output_grad * input;
-              y_vals_grad_a[c][n] += output_grad;
-              input_grad_a[b][c][t] = output_grad * params_a[c][n + 1];
-            }
-          }
-        }
-        // Now do the backprop for the loop above where we set y_vals_a.
-        for (int c = 0; c < C; c++) {
-          scalar_t scale = exp(params_a[c][0]),
-              scale_grad = 0.0,
-              sum_negative_grad = 0.0,
-              sum_positive_grad = 0.0;
-          for (int i = K - 1; i >= 0; i--) {
-            // Backprop for: y_vals_a[c][K - i - 1] = sum_negative + neg_scaled_param * (i + 1):
-            scalar_t y_grad_neg = y_vals_grad_a[c][K - i - 1];
-            sum_negative_grad += y_grad_neg;
-            scalar_t neg_scaled_param_grad = y_grad_neg * (i + 1);
-            // Backprop for: sum_negative -= neg_scaled_param;
-            neg_scaled_param_grad -= sum_negative_grad;
-            // Backprop for: sum_positive += pos_scaled_param;
-            scalar_t pos_scaled_param_grad = sum_positive_grad;
-            // Backprop for: y_vals_a[c][K + i] = sum_positive - pos_scaled_param * i;
-            scalar_t y_grad_pos = y_vals_grad_a[c][K + i];
-            pos_scaled_param_grad -= i * y_grad_pos;
-            sum_positive_grad += y_grad_pos;
-            // Backprop for: pos_scaled_param = params_a[c][1 + K + i] * scale,
-            // and: neg_scaled_param = params_a[c][K - i] * scale;
-            params_grad_a[c][1 + K + i] += pos_scaled_param_grad * scale;
-            params_grad_a[c][K - i] += neg_scaled_param_grad * scale;
-            scale_grad += (pos_scaled_param_grad * params_a[c][1 + K + i] +
-                           neg_scaled_param_grad * params_a[c][K - i]);
-          }
-          // Backprop for: scale = exp(params_a[c][0]),
-          params_grad_a[c][0] += scale * scale_grad;
-        }
-      }));
-  return std::vector<torch::Tensor>({input_grad, params_grad});
+  std::cout << "p_grad = " << p_grad;
+  return std::vector<torch::Tensor>({px_grad, py_grad});
 }
torch_mutual_information/mutual_information_cuda_kernel.cu (view file @ 3c1ec347)

(This diff is collapsed: +228, -191.)