Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
ae1d18d6
Commit
ae1d18d6
authored
Jul 30, 2021
by
Daniel Povey
Browse files
More bug fices, just about working...
parent
53b31903
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
34 deletions
+33
-34
MANIFEST.in
MANIFEST.in
+1
-1
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+1
-1
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+23
-24
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+8
-8
No files found.
MANIFEST.in
View file @
ae1d18d6
...
...
@@ -2,6 +2,6 @@ include requirements.txt
include pyproject.toml
include LICENSE*
recursive-include torch_mutual_information *
p
recursive-include doc/img *
recursive-include doc/img *
recursive-include tests *
global-exclude *.pyc
\ No newline at end of file
torch_mutual_information/mutual_information.py
View file @
ae1d18d6
...
...
@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
print
(
f
"p =
{
p
}
, boundary =
{
boundary
}
"
)
#
print(f"p = {p}, boundary = {boundary}
, psum={p.sum()}
")
if
px
.
requires_grad
or
py
.
requires_grad
:
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
...
...
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
ae1d18d6
...
...
@@ -172,7 +172,7 @@ void mutual_information_kernel(
// block < num_blocks_this_iter, so iter - block >= 0.
int
s_block_begin
=
block
*
BLOCK_SIZE
,
t_block_begin
=
(
iter
-
block
)
*
BLOCK_SIZE
;
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
bool
is_origin_block
=
(
s_block_begin
+
t_block_begin
==
0
);
__syncthreads
();
...
...
@@ -403,30 +403,29 @@ void mutual_information_kernel(
if
(
p_buf
[
0
][
0
]
!=
0.0
)
{
if
(
threadIdx
.
x
==
0
)
printf
(
"Panic flag set, value = %f
\n
"
,
(
float
)
p_buf
[
0
][
0
]);
// TEMP
printf
(
"Panic flag set, value = %f
\n
"
,
(
float
)
p_buf
[
0
][
0
]);
// TEMP
?
// The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
int
s_in_block
=
threadIdx
.
x
;
if
(
s_in_block
<
block_S
)
{
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
__syncwarp
();
int
t_in_block
=
i
-
s_in_block
;
if
(
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
int
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
float
p_s1
=
(
s
==
s_begin
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
this_px
=
(
s
==
s_begin
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
p_t1
=
(
t
==
t_begin
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
this_py
=
(
t
==
t_begin
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
p_t1
+
this_py
);
if
(
i
==
0
&&
is_origin_block
)
this_p
=
0.0
;
p
[
b
][
s
][
t
]
=
this_p
;
}
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
__syncwarp
();
int
t_in_block
=
i
-
s_in_block
;
if
(
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
)
&&
s_in_block
<
block_S
)
{
int
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
float
p_s1
=
(
s
==
s_begin
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
this_px
=
(
s
==
s_begin
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
p_t1
=
(
t
==
t_begin
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
this_py
=
(
t
==
t_begin
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
p_t1
+
this_py
);
if
(
i
==
0
&&
is_origin_block
)
this_p
=
0.0
;
p
[
b
][
s
][
t
]
=
this_p
;
}
}
__syncwarp
();
...
...
@@ -649,14 +648,13 @@ void mutual_information_backward_kernel(
}
__syncthreads
();
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// reads more aligned.
// load p.
for
(
int
i
=
threadIdx
.
x
;
i
<
(
BLOCK_SIZE
+
1
)
*
(
BLOCK_SIZE
+
1
);
i
+=
blockDim
.
x
)
{
int
s_in_block
=
i
/
(
BLOCK_SIZE
+
1
),
t_in_block
=
i
%
(
BLOCK_SIZE
+
1
),
s
=
s_in_block
+
s_block_begin
,
t
=
t_in_block
+
t_block_begin
;
// Setting 0.0 for out-of-bounds elements, together with setting
// Setting 0.0 for out-of-bounds elements
of p
, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases,
// i.e. that no derivatives will be propagated from out-of-bounds points
...
...
@@ -742,7 +740,8 @@ void mutual_information_backward_kernel(
for
(
int
i
=
first_iter
;
i
>=
0
;
--
i
)
{
__syncwarp
();
int
t
=
i
-
s
;
if
(
t
>=
0
&&
s
<
block_S
)
{
if
(
s
<
block_S
&&
static_cast
<
unsigned
int
>
(
t
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
...
...
torch_mutual_information/mutual_information_test.py
View file @
ae1d18d6
...
...
@@ -13,11 +13,11 @@ def test_mutual_information_basic():
(
B
,
S
,
T
)
=
(
random
.
randint
(
1
,
10
),
random
.
randint
(
1
,
200
),
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.
1
)
random_py
=
(
random
.
random
()
<
0.
1
)
random_boundary
=
(
random
.
random
()
<
0.
1
)
big_px
=
(
random
.
random
()
<
0.
1
)
big_py
=
(
random
.
random
()
<
0.
1
)
random_px
=
(
random
.
random
()
<
0.
2
)
random_py
=
(
random
.
random
()
<
0.
2
)
random_boundary
=
(
random
.
random
()
<
0.
2
)
big_px
=
(
random
.
random
()
<
0.
2
)
big_py
=
(
random
.
random
()
<
0.
2
)
print
(
f
"B, S, T =
{
B
}
,
{
S
}
,
{
T
}
, random_px=
{
random_px
}
, random_py=
{
random_py
}
, big_px=
{
big_px
}
, big_py=
{
big_py
}
, random_boundary=
{
random_boundary
}
"
)
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
...
...
@@ -81,13 +81,13 @@ def test_mutual_information_basic():
px_grads
.
append
(
px
.
grad
.
to
(
'cpu'
))
py_grads
.
append
(
py
.
grad
.
to
(
'cpu'
))
m_vals
.
append
(
m
.
to
(
'cpu'
))
if
not
torch
.
allclose
(
m_vals
[
0
],
m_vals
[
1
],
atol
=
1.0e-0
5
,
rtol
=
1.0e-0
4
):
if
not
torch
.
allclose
(
m_vals
[
0
],
m_vals
[
1
],
atol
=
1.0e-0
2
,
rtol
=
1.0e-0
2
):
print
(
f
"m_vals differed CPU vs CUDA:
{
m_vals
[
0
]
}
vs.
{
m_vals
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-0
5
,
rtol
=
1.0e-0
4
):
if
not
torch
.
allclose
(
px_grads
[
0
],
px_grads
[
1
],
atol
=
1.0e-0
2
,
rtol
=
1.0e-0
2
):
print
(
f
"px_grads differed CPU vs CUDA:
{
px_grads
[
0
]
}
vs.
{
px_grads
[
1
]
}
"
)
assert
0
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-0
5
,
rtol
=
1.0e-0
3
):
if
not
torch
.
allclose
(
py_grads
[
0
],
py_grads
[
1
],
atol
=
1.0e-0
2
,
rtol
=
1.0e-0
2
):
print
(
f
"py_grads differed CPU vs CUDA:
{
py_grads
[
0
]
}
vs.
{
py_grads
[
1
]
}
"
)
assert
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment