Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
b172f0bc
Commit
b172f0bc
authored
Jul 31, 2021
by
Daniel Povey
Browse files
Finished debugging and testing, I believe.
parent
ae1d18d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
118 additions
and
5 deletions
+118
-5
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+8
-4
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+110
-1
No files found.
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
b172f0bc
...
...
@@ -376,7 +376,7 @@ void mutual_information_kernel(
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
if
(
s_in_block
<
block_S
&&
t_in_block
<
block_T
)
{
floa
t
this_p
=
p_buf
[
s_in_block
+
1
][
t_in_block
+
1
];
scalar_
t
this_p
=
p_buf
[
s_in_block
+
1
][
t_in_block
+
1
];
p
[
b
][
s
][
t
]
=
normalizer
+
log
(
this_p
);
// If this_p is infinity or NaN..
if
(
this_p
-
this_p
!=
0
)
{
...
...
@@ -402,8 +402,12 @@ void mutual_information_kernel(
}
if
(
p_buf
[
0
][
0
]
!=
0.0
)
{
/*
// FOR DEBUGGING PANIC MODE:
if (threadIdx.x == 0)
printf
(
"Panic flag set, value = %f
\n
"
,
(
float
)
p_buf
[
0
][
0
]);
// TEMP?
printf("Panic flag set, value = %f\n", (float)p_buf[0][0]);
*/
// The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
...
...
@@ -417,11 +421,11 @@ void mutual_information_kernel(
s_in_block
<
block_S
)
{
int
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
floa
t
p_s1
=
(
s
==
s_begin
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
scalar_
t
p_s1
=
(
s
==
s_begin
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
this_px
=
(
s
==
s_begin
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
p_t1
=
(
t
==
t_begin
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
this_py
=
(
t
==
t_begin
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
floa
t
this_p
=
LogAdd
(
p_s1
+
this_px
,
scalar_
t
this_p
=
LogAdd
(
p_s1
+
this_px
,
p_t1
+
this_py
);
if
(
i
==
0
&&
is_origin_block
)
this_p
=
0.0
;
...
...
torch_mutual_information/mutual_information_test.py
View file @
b172f0bc
...
...
@@ -92,10 +92,119 @@ def test_mutual_information_basic():
assert
0
def test_mutual_information_deriv():
    """Gradient check for mutual_information_recursion.

    For random problem sizes (B, S, T) and random configurations of
    px/py (zeros vs. randn, optionally shifted by +15 to stress the
    overflow/"panic" path in the CUDA kernel), this compares the
    analytic gradient from backward() against a finite-difference
    estimate, on both CPU and CUDA, in float32 and float64.

    NOTE(review): requires a CUDA device ('cuda:0') and the project's
    mutual_information_recursion; not runnable stdlib-only.
    """
    print("Running test_mutual_information_deriv()")
    for _iter in range(100):
        (B, S, T) = (random.randint(1, 10),
                     random.randint(1, 200),
                     random.randint(1, 200))
        random_px = (random.random() < 0.2)
        random_py = (random.random() < 0.2)
        random_boundary = (random.random() < 0.2)
        big_px = (random.random() < 0.2)
        big_py = (random.random() < 0.2)
        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, "
              f"random_py={random_py}, big_px={big_px}, big_py={big_py}, "
              f"random_boundary={random_boundary}")
        for dtype in [torch.float32, torch.float64]:
            for device in [torch.device('cpu'), torch.device('cuda:0')]:
                print("dtype = ", dtype, ", device = ", device)
                if random_boundary:
                    def get_boundary_row():
                        # [s_begin, t_begin, s_end, t_end] with non-empty ranges.
                        s_begin = random.randint(0, S - 1)
                        t_begin = random.randint(0, T - 1)
                        s_end = random.randint(s_begin + 1, S)
                        t_end = random.randint(t_begin + 1, T)
                        return [s_begin, t_begin, s_end, t_end]
                    if device == torch.device('cpu'):
                        boundary = torch.tensor(
                            [get_boundary_row() for _ in range(B)],
                            dtype=torch.int64, device=device)
                    else:
                        # Re-use the boundary generated on the CPU pass so
                        # both devices see identical inputs.
                        boundary = boundary.to(device)
                else:
                    # Use default boundary, but either specified directly or not.
                    if random.random() < 0.5:
                        boundary = torch.tensor(
                            [0, 0, S, T],
                            dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
                    else:
                        boundary = None

                if device == torch.device('cpu'):
                    if random_px:
                        px = torch.randn(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    # px and py get exponentiated, and then multiplied together up to
                    # 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
                    # could lead to overflow.
                    if big_px:
                        px += 15.0
                    if random_py:
                        py = torch.randn(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    if big_py:
                        py += 15.0
                else:
                    # Re-use the CPU tensors so CPU and CUDA see identical inputs.
                    px = px.to(device).detach()
                    py = py.to(device).detach()
                px.requires_grad = True
                py.requires_grad = True

                m = mutual_information_recursion(px, py, boundary)

                # Back-propagate a random gradient so every batch element is
                # weighted differently in the finite-difference comparison.
                m_grad = torch.randn(B, dtype=dtype, device=device)
                m.backward(gradient=m_grad)

                delta = 1.0e-04
                delta_px = delta * torch.randn_like(px)
                m2 = mutual_information_recursion(px + delta_px, py, boundary)
                delta_m = m2 - m
                observed_delta = (delta_m * m_grad).sum().to('cpu')
                predicted_delta = (delta_px * px.grad).sum().to('cpu')
                print(f"For px: observed,predicted objf changes are: "
                      f"{observed_delta},{predicted_delta}, absolute objf was "
                      f"{(m * m_grad).sum()}")

                atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
                rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
                if not torch.allclose(observed_delta, predicted_delta,
                                      atol=atol, rtol=rtol):
                    print(f"Error: observed and predicted delta too different.")
                    assert 0

                delta_py = delta * torch.randn_like(py)
                m2 = mutual_information_recursion(px, py + delta_py, boundary)
                delta_m = m2 - m
                observed_delta = (delta_m * m_grad).sum().to('cpu')
                predicted_delta = (delta_py * py.grad).sum().to('cpu')
                print(f"For py: observed,predicted objf changes are: "
                      f"{observed_delta},{predicted_delta}, absolute objf was "
                      f"{(m * m_grad).sum()}")
                # BUGFIX(review): the original only printed the py comparison
                # and never asserted, so a wrong py gradient could pass the
                # test silently; mirror the px check above.
                if not torch.allclose(observed_delta, predicted_delta,
                                      atol=atol, rtol=rtol):
                    print(f"Error: observed and predicted delta too different.")
                    assert 0
if __name__ == "__main__":
    #torch.set_printoptions(edgeitems=30)
    # Run the basic sanity test first, then the gradient (derivative) check.
    test_mutual_information_basic()
    test_mutual_information_deriv()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment