Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
c39d5fe7
Commit
c39d5fe7
authored
Jul 30, 2021
by
Daniel Povey
Browse files
More (randomized) testing; bug fixes.
parent
9ebcf9d5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
60 additions
and
35 deletions
+60
-35
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+1
-1
torch_mutual_information/mutual_information_cpu.cpp
torch_mutual_information/mutual_information_cpu.cpp
+3
-3
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+1
-1
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+55
-30
No files found.
torch_mutual_information/mutual_information.py
View file @
c39d5fe7
...
@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
...
@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
print
(
f
"p =
{
p
}
, boundary =
{
boundary
}
"
)
#
print(f"p = {p}, boundary = {boundary}")
if
px
.
requires_grad
or
py
.
requires_grad
:
if
px
.
requires_grad
or
py
.
requires_grad
:
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
...
...
torch_mutual_information/mutual_information_cpu.cpp
View file @
c39d5fe7
...
@@ -231,16 +231,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
...
@@ -231,16 +231,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
// .. but we can use this for a check, that the grad at the beginning
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
// of the sequence is equal to the grad at the end of the sequence.
if
(
ans_grad_a
[
b
]
!=
0.0
)
{
if
(
ans_grad_a
[
b
]
!=
0.0
)
{
float
grad_ratio
=
p_a
[
b
][
s_begin
][
t_begin
]
/
ans_grad_a
[
b
];
float
grad_ratio
=
p_
grad_
a
[
b
][
s_begin
][
t_begin
]
/
ans_grad_a
[
b
];
if
(
fabs
(
grad_ratio
-
1.0
)
>
0.01
)
{
if
(
fabs
(
grad_ratio
-
1.0
)
>
0.01
)
{
printf
(
"Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f
\n
"
,
printf
(
"Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f
\n
"
,
(
float
)
p_a
[
b
][
s_begin
][
t_begin
],
(
float
)
ans_grad_a
[
b
]);
(
float
)
p_
grad_
a
[
b
][
s_begin
][
t_begin
],
(
float
)
ans_grad_a
[
b
]);
}
}
}
}
}
}
}));
}));
std
::
cout
<<
"p_grad = "
<<
p_grad
;
//
std::cout << "p_grad = " << p_grad;
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
}
}
...
...
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
c39d5fe7
...
@@ -915,6 +915,6 @@ mutual_information_backward_cuda(torch::Tensor px,
...
@@ -915,6 +915,6 @@ mutual_information_backward_cuda(torch::Tensor px,
overwrite_ans_grad
);
overwrite_ans_grad
);
}
}
}));
}));
std
::
cout
<<
"p_grad = "
<<
p_grad
;
//
std::cout << "p_grad = " << p_grad;
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
}
}
torch_mutual_information/mutual_information_test.py
View file @
c39d5fe7
...
@@ -8,36 +8,61 @@ from torch_mutual_information import mutual_information_recursion
...
@@ -8,36 +8,61 @@ from torch_mutual_information import mutual_information_recursion
def
test_mutual_information_basic
():
def
test_mutual_information_basic
():
print
(
"Running test_mutual_information_basic()"
)
print
(
"Running test_mutual_information_basic()"
)
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
px_grads
=
[]
for
_iter
in
range
(
100
):
py_grads
=
[]
(
B
,
S
,
T
)
=
(
random
.
randint
(
1
,
10
),
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
random
.
randint
(
1
,
200
),
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
random
.
randint
(
1
,
200
))
B
=
2
random_px
=
(
random
.
random
()
<
0.1
)
S
=
14
random_py
=
(
random
.
random
()
<
0.1
)
T
=
14
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
print
(
f
"B, S, T =
{
B
}
,
{
S
}
,
{
T
}
, random_px=
{
random_px
}
, random_py=
{
random_py
}
"
)
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
px_grads
=
[]
px
.
requires_grad
=
True
py_grads
=
[]
py
.
requires_grad
=
True
m_vals
=
[]
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
#m = mutual_information_recursion(px, py, None)
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
B
=
2
S
=
14
print
(
"m = "
,
m
,
", size = "
,
m
.
shape
)
T
=
14
print
(
"exp(m) = "
,
m
.
exp
())
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
(
m
.
sum
()
*
3
).
backward
()
print
(
"px_grad = "
,
px
.
grad
)
if
device
==
torch
.
device
(
'cpu'
):
print
(
"py_grad = "
,
py
.
grad
)
if
random_px
:
px_grads
.
append
(
px
.
grad
.
to
(
'cpu'
))
px
=
torch
.
randn
(
B
,
S
,
T
+
1
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
py_grads
.
append
(
py
.
grad
.
to
(
'cpu'
))
else
:
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
]):
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
if
random_py
:
assert
0
py
=
torch
.
randn
(
B
,
S
+
1
,
T
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
]):
else
:
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
assert
0
else
:
px
=
px
.
to
(
device
).
detach
()
py
=
py
.
to
(
device
).
detach
()
px
.
requires_grad
=
True
py
.
requires_grad
=
True
#m = mutual_information_recursion(px, py, None)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
#print("m = ", m, ", size = ", m.shape)
#print("exp(m) = ", m.exp())
(
m
.
sum
()
*
3
).
backward
()
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
px_grads
.
append
(
px
.
grad
.
to
(
'cpu'
))
py_grads
.
append
(
py
.
grad
.
to
(
'cpu'
))
m_vals
.
append
(
m
.
to
(
'cpu'
))
if
not
torch
.
allclose
(
m_vals
[
0
],
m_vals
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"m_vals differed CPU vs CUDA:
{
m_vals
[
0
]
}
vs.
{
m_vals
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
assert
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment