Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
53b31903
Commit
53b31903
authored
Jul 30, 2021
by
Daniel Povey
Browse files
Fix some bugs..
parent
17b18990
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
106 deletions
+43
-106
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+2
-2
torch_mutual_information/mutual_information_cpu.cpp
torch_mutual_information/mutual_information_cpu.cpp
+1
-1
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+33
-29
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+7
-74
No files found.
torch_mutual_information/mutual_information.py
View file @
53b31903
...
@@ -68,7 +68,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
...
@@ -68,7 +68,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
px
,
py
,
boundary
,
p
,
ans_grad_copy
,
overwrite_ans_grad
))
px
,
py
,
boundary
,
p
,
ans_grad_copy
,
overwrite_ans_grad
))
if
overwrite_ans_grad
:
if
overwrite_ans_grad
:
if
not
torch
.
allclose
(
ans_grad
,
ans_grad_copy
,
rtol
=
1.0e-02
):
if
not
torch
.
allclose
(
ans_grad
,
ans_grad_copy
,
rtol
=
1.0e-02
):
print
(
f
"Warning: possible excsssive roundoff in mutual information backward "
print
(
f
"Warning: possible exc
e
sssive roundoff in mutual information backward "
f
"recursion:
{
ans_grad
}
vs.
{
ans_grad_copy
}
"
);
f
"recursion:
{
ans_grad
}
vs.
{
ans_grad_copy
}
"
);
return
ans
return
ans
else
:
else
:
...
@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
...
@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
#
print(f"p = {p}, boundary = {boundary}")
print
(
f
"p =
{
p
}
, boundary =
{
boundary
}
"
)
if
px
.
requires_grad
or
py
.
requires_grad
:
if
px
.
requires_grad
or
py
.
requires_grad
:
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
...
...
torch_mutual_information/mutual_information_cpu.cpp
View file @
53b31903
...
@@ -97,8 +97,8 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
...
@@ -97,8 +97,8 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
t_end
=
boundary_a
[
b
][
3
];
t_end
=
boundary_a
[
b
][
3
];
}
else
{
}
else
{
s_begin
=
0
;
s_begin
=
0
;
s_end
=
S
;
t_begin
=
0
;
t_begin
=
0
;
s_end
=
S
;
t_end
=
T
;
t_end
=
T
;
}
}
p_a
[
b
][
s_begin
][
t_begin
]
=
0.0
;
p_a
[
b
][
s_begin
][
t_begin
]
=
0.0
;
...
...
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
53b31903
...
@@ -208,16 +208,16 @@ void mutual_information_kernel(
...
@@ -208,16 +208,16 @@ void mutual_information_kernel(
t_in_block
=
i
%
BLOCK_SIZE
,
t_in_block
=
i
%
BLOCK_SIZE
,
s
=
s_in_block
+
s_block_begin
,
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
t
=
t_in_block
+
t_block_begin
;
// comparing as unsigned int makes sure the index is nonnegative.
// comparing as unsigned int makes sure the index is nonnegative.
// Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px and
// py values that are outside the proper boundaries that we need, but
// the corresponding p_buf values will end up being 0 so this won't matter.
scalar_t
this_px
=
0.0
;
scalar_t
this_px
=
0.0
;
if
(
static_cast
<
unsigned
int
>
(
s
-
1
)
<
static_cast
<
unsigned
int
>
(
s_end
)
&&
if
(
s
>
s_begin
&&
s
<=
s_end
&&
t
<=
t_end
)
t
<=
t_end
)
this_px
=
exp
(
px
[
b
][
s
-
1
][
t
]);
this_px
=
exp
(
px
[
b
][
s
-
1
][
t
]);
px_buf
[
s_in_block
][
t_in_block
]
=
this_px
;
px_buf
[
s_in_block
][
t_in_block
]
=
this_px
;
scalar_t
this_py
=
0.0
;
scalar_t
this_py
=
0.0
;
if
(
static_cast
<
unsigned
int
>
(
t
-
1
)
<
static_cast
<
unsigned
int
>
(
t_end
)
&&
if
(
t
>
t_begin
&&
t
<=
t_end
&&
s
<=
s_end
)
s
<=
s_end
)
this_py
=
exp
(
py
[
b
][
s
][
t
-
1
]);
this_py
=
exp
(
py
[
b
][
s
][
t
-
1
]);
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
}
}
...
@@ -234,32 +234,28 @@ void mutual_information_kernel(
...
@@ -234,32 +234,28 @@ void mutual_information_kernel(
s
=
s_in_p_buf
+
s_block_begin
-
1
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
t
=
t_in_p_buf
+
t_block_begin
-
1
;
t
=
t_in_p_buf
+
t_block_begin
-
1
;
scalar_t
this_p
=
-
INFINITY
;
scalar_t
this_p
=
-
INFINITY
;
if
(
s
tatic_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
if
(
s
>=
s_begin
&&
s
<=
s_end
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
)
)
{
t
>=
t_begin
&&
t
<=
t_end
)
this_p
=
p
[
b
][
s
][
t
];
this_p
=
p
[
b
][
s
][
t
];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
}
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
else
{
}
else
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<=
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
// Another warp handles the other leg. Checking as unsigned
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<=
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
int
s_in_p_buf
=
0
,
int
s_in_p_buf
=
0
,
t_in_p_buf
=
(
int
)
threadIdx
.
x
-
64
,
t_in_p_buf
=
(
int
)
threadIdx
.
x
-
64
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
t
=
t_in_p_buf
+
t_block_begin
-
1
;
t
=
t_in_p_buf
+
t_block_begin
-
1
;
scalar_t
this_p
=
-
INFINITY
;
scalar_t
this_p
=
-
INFINITY
;
if
(
s
tatic_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
if
(
s
>=
s_begin
&&
s
<=
s_end
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
)
)
{
t
>=
t_begin
&&
t
<=
t_end
)
this_p
=
p
[
b
][
s
][
t
];
this_p
=
p
[
b
][
s
][
t
];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
}
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
}
}
__syncthreads
();
__syncthreads
();
...
@@ -421,10 +417,10 @@ void mutual_information_kernel(
...
@@ -421,10 +417,10 @@ void mutual_information_kernel(
static_cast
<
unsigned
int
>
(
block_T
))
{
static_cast
<
unsigned
int
>
(
block_T
))
{
int
s
=
s_in_block
+
s_block_begin
,
int
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
t
=
t_in_block
+
t_block_begin
;
float
p_s1
=
(
s
==
0
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
float
p_s1
=
(
s
==
s_begin
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
this_px
=
(
s
==
0
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
this_px
=
(
s
==
s_begin
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
p_t1
=
(
t
==
0
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
p_t1
=
(
t
==
t_begin
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
this_py
=
(
t
==
0
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
this_py
=
(
t
==
t_begin
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
p_t1
+
this_py
);
p_t1
+
this_py
);
if
(
i
==
0
&&
is_origin_block
)
if
(
i
==
0
&&
is_origin_block
)
...
@@ -433,6 +429,7 @@ void mutual_information_kernel(
...
@@ -433,6 +429,7 @@ void mutual_information_kernel(
}
}
}
}
}
}
__syncwarp
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
// Write `ans`, if this is the final (top-right) block in its sequence.
// Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow.
// This is only reached in the 'panic situation' where we had overflow.
...
@@ -650,6 +647,7 @@ void mutual_information_backward_kernel(
...
@@ -650,6 +647,7 @@ void mutual_information_backward_kernel(
this_py
=
py
[
b
][
s
][
t
];
this_py
=
py
[
b
][
s
][
t
];
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
py_buf
[
s_in_block
][
t_in_block
]
=
this_py
;
}
}
__syncthreads
();
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// reads more aligned.
// reads more aligned.
...
@@ -669,6 +667,8 @@ void mutual_information_backward_kernel(
...
@@ -669,6 +667,8 @@ void mutual_information_backward_kernel(
p_buf
[
s_in_block
][
t_in_block
]
=
this_p
;
p_buf
[
s_in_block
][
t_in_block
]
=
this_p
;
}
}
__syncthreads
();
// Set xderiv and yderiv; see (eq. 4) and (eq. 5).
// Set xderiv and yderiv; see (eq. 4) and (eq. 5).
for
(
int
i
=
threadIdx
.
x
;
i
<
BLOCK_SIZE
*
BLOCK_SIZE
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
BLOCK_SIZE
*
BLOCK_SIZE
;
i
+=
blockDim
.
x
)
{
// We can apply this formula to the entire block even if we are processing
// We can apply this formula to the entire block even if we are processing
...
@@ -687,6 +687,8 @@ void mutual_information_backward_kernel(
...
@@ -687,6 +687,8 @@ void mutual_information_backward_kernel(
py_buf
[
s
][
t
]
=
exp
(
p_buf
[
s
][
t
]
+
py_buf
[
s
][
t
]
-
p_buf
[
s
][
t
+
1
]);
py_buf
[
s
][
t
]
=
exp
(
p_buf
[
s
][
t
]
+
py_buf
[
s
][
t
]
-
p_buf
[
s
][
t
+
1
]);
}
}
__syncthreads
();
// Load p_grad for the top and right elements in p_buf: i.e. for elements
// Load p_grad for the top and right elements in p_buf: i.e. for elements
// p_buf[s][t] where s == block_S (exclusive-or) t == block_T. We don't
// p_buf[s][t] where s == block_S (exclusive-or) t == block_T. We don't
// need to load the top-right corner [block_S][block_T]; that location will
// need to load the top-right corner [block_S][block_T]; that location will
...
@@ -714,6 +716,8 @@ void mutual_information_backward_kernel(
...
@@ -714,6 +716,8 @@ void mutual_information_backward_kernel(
s
<=
s_end
&&
t
<=
t_end
?
p_grad
[
b
][
s
][
t
]
:
0.0
);
s
<=
s_end
&&
t
<=
t_end
?
p_grad
[
b
][
s
][
t
]
:
0.0
);
}
}
__syncthreads
();
// The highest-numbered value in p_buf that we need (corresponding,
// The highest-numbered value in p_buf that we need (corresponding,
// of course, to p_grad), is:
// of course, to p_grad), is:
// p_buf[block_S - 1][block_T - 1],
// p_buf[block_S - 1][block_T - 1],
...
...
torch_mutual_information/mutual_information_test.py
View file @
53b31903
...
@@ -35,8 +35,11 @@ def test_mutual_information_basic():
...
@@ -35,8 +35,11 @@ def test_mutual_information_basic():
s_end
=
random
.
randint
(
s_begin
+
1
,
S
)
s_end
=
random
.
randint
(
s_begin
+
1
,
S
)
t_end
=
random
.
randint
(
t_begin
+
1
,
T
)
t_end
=
random
.
randint
(
t_begin
+
1
,
T
)
return
[
s_begin
,
t_begin
,
s_end
,
t_end
]
return
[
s_begin
,
t_begin
,
s_end
,
t_end
]
if
device
==
torch
.
device
(
'cpu'
):
boundary
=
torch
.
tensor
([
get_boundary_row
()
for
_
in
range
(
B
)
],
boundary
=
torch
.
tensor
([
get_boundary_row
()
for
_
in
range
(
B
)
],
dtype
=
torch
.
int64
,
device
=
device
)
dtype
=
torch
.
int64
,
device
=
device
)
else
:
boundary
=
boundary
.
to
(
device
)
else
:
else
:
# Use default boundary, but either specified directly or not.
# Use default boundary, but either specified directly or not.
if
random
.
random
()
<
0.5
:
if
random
.
random
()
<
0.5
:
...
@@ -84,78 +87,7 @@ def test_mutual_information_basic():
...
@@ -84,78 +87,7 @@ def test_mutual_information_basic():
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
assert
0
assert
0
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-03
):
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
assert
0
def
test_mutual_information_deriv
():
print
(
"Running test_mutual_information_basic()"
)
for
_iter
in
range
(
100
):
(
B
,
S
,
T
)
=
(
random
.
randint
(
1
,
10
),
random
.
randint
(
1
,
200
),
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.1
)
random_py
=
(
random
.
random
()
<
0.1
)
big_px
=
(
random
.
random
()
<
0.1
)
big_py
=
(
random
.
random
()
<
0.1
)
print
(
f
"B, S, T =
{
B
}
,
{
S
}
,
{
T
}
, random_px=
{
random_px
}
, random_py=
{
random_py
}
, big_px=
{
big_px
}
, big_py=
{
big_py
}
"
)
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
px_grads
=
[]
py_grads
=
[]
m_vals
=
[]
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
B
=
2
S
=
14
T
=
14
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
if
device
==
torch
.
device
(
'cpu'
):
if
random_px
:
px
=
torch
.
randn
(
B
,
S
,
T
+
1
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
else
:
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if
big_px
:
px
+=
15.0
if
random_py
:
py
=
torch
.
randn
(
B
,
S
+
1
,
T
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
else
:
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
,
dtype
=
dtype
).
to
(
device
)
# log of an odds ratio
if
big_py
:
py
+=
15.0
else
:
px
=
px
.
to
(
device
).
detach
()
py
=
py
.
to
(
device
).
detach
()
px
.
requires_grad
=
True
py
.
requires_grad
=
True
#m = mutual_information_recursion(px, py, None)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
#print("m = ", m, ", size = ", m.shape)
#print("exp(m) = ", m.exp())
(
m
.
sum
()
*
3
).
backward
()
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
px_grads
.
append
(
px
.
grad
.
to
(
'cpu'
))
py_grads
.
append
(
py
.
grad
.
to
(
'cpu'
))
m_vals
.
append
(
m
.
to
(
'cpu'
))
if
not
torch
.
allclose
(
m_vals
[
0
],
m_vals
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"m_vals differed CPU vs CUDA:
{
m_vals
[
0
]
}
vs.
{
m_vals
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-05
,
rtol
=
1.0e-04
):
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
assert
0
assert
0
...
@@ -164,5 +96,6 @@ def test_mutual_information_deriv():
...
@@ -164,5 +96,6 @@ def test_mutual_information_deriv():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
#torch.set_printoptions(edgeitems=30)
test_mutual_information_basic
()
test_mutual_information_basic
()
#test_mutual_information_deriv()
#test_mutual_information_deriv()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment