Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
02a36166
Commit
02a36166
authored
Jul 29, 2021
by
Daniel Povey
Browse files
Fix more bugs, add some debug statements.
parent
9ac065f0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
80 deletions
+100
-80
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+96
-77
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+4
-3
No files found.
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
02a36166
...
...
@@ -168,11 +168,14 @@ void mutual_information_kernel(
b
=
batch_block_iter
%
B
;
// b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
// block < num_blocks_this_iter, so iter - block >= 0.
int
s_block_begin
=
block
*
BLOCK_SIZE
,
t_block_begin
=
(
iter
-
block
)
*
BLOCK_SIZE
;
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
assert
(
b
<
B
&&
b
>=
0
&&
t_block_begin
>=
0
);
// TODO: remove.
if
(
boundary
.
size
(
0
)
!=
0
&&
threadIdx
.
x
<
4
)
boundary_buf
[
threadIdx
.
x
]
=
boundary
[
b
][
threadIdx
.
x
];
__syncthreads
();
...
...
@@ -233,10 +236,13 @@ void mutual_information_kernel(
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
this_p
=
p
[
b
][
s
][
t
];
p_buf
[
threadIdx
.
x
][
0
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
}
else
{
// Another warp handles the other leg
if
(
int
(
threadIdx
.
x
)
-
64
<=
BLOCK_SIZE
)
{
}
else
{
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<=
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
int
s_in_p_buf
=
0
,
t_in_p_buf
=
threadIdx
.
x
-
64
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
...
...
@@ -247,7 +253,7 @@ void mutual_information_kernel(
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
this_p
=
p
[
b
][
s
][
t
];
p_buf
[
threadIdx
.
x
][
0
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
}
...
...
@@ -269,11 +275,15 @@ void mutual_information_kernel(
// p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator.
p_buf
[
threadIdx
.
x
][
0
]
=
(
threadIdx
.
x
==
0
?
0.0
:
exp
(
p_buf
[
threadIdx
.
x
][
0
]
-
normalizer
));
}
else
if
(
int
(
threadIdx
.
x
)
-
64
<
BLOCK_SIZE
)
{
int
s
=
threadIdx
.
x
;
p_buf
[
s
][
0
]
=
(
s
==
0
?
0.0
:
exp
(
p_buf
[
s
][
0
]
-
normalizer
));
}
else
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
// i.e. if (threadIdx.x - 64) >= 0 && (threadIdx.x - 64) < BLOCK_SIZE.
int
t
=
(
int
)
threadIdx
.
x
-
64
+
1
;
// 0 < t <= BLOCK_SIZE
// this happens in a different warp so can be in parallel to the code above.
p_buf
[
0
][
t
hreadIdx
.
x
+
1
]
=
exp
(
p_buf
[
0
][
t
hreadIdx
.
x
+
1
]
-
normalizer
);
p_buf
[
0
][
t
]
=
exp
(
p_buf
[
0
][
t
]
-
normalizer
);
}
...
...
@@ -298,6 +308,8 @@ void mutual_information_kernel(
p_buf_s1_t
=
p_buf
[
s
+
1
][
threadIdx
.
x
==
0
?
1
:
0
];
}
int
s
=
threadIdx
.
x
;
if
(
s
<
block_S
)
{
for
(
int
i
=
1
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
...
...
@@ -343,6 +355,7 @@ void mutual_information_kernel(
// threads that are active are in the same warp. (However, in future,
// if NVIDIA changes some things, we might need to sync here).
}
}
__syncthreads
();
}
...
...
@@ -382,9 +395,10 @@ void mutual_information_kernel(
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
int
s_in_block
=
threadIdx
.
x
;
if
(
s_in_block
<
block_S
)
{
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
int
s_in_block
=
threadIdx
.
x
,
t_in_block
=
i
-
s_in_block
;
int
t_in_block
=
i
-
s_in_block
;
if
(
s_in_block
<
block_S
&&
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
int
s
=
s_in_block
+
s_block_begin
,
...
...
@@ -400,6 +414,7 @@ void mutual_information_kernel(
p
[
b
][
s
][
t
]
=
this_p
;
}
}
}
if
(
threadIdx
.
x
==
0
)
{
// Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow.
...
...
@@ -700,9 +715,11 @@ void mutual_information_backward_kernel(
--
first_iter
;
}
{
int
s
=
threadIdx
.
x
;
if
(
s
<
block_S
)
{
for
(
int
i
=
first_iter
;
i
>=
0
;
--
i
)
{
int
s
=
i
,
t
=
i
-
threadIdx
.
x
;
int
t
=
i
-
s
;
if
(
t
>=
0
)
{
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
...
...
@@ -713,6 +730,8 @@ void mutual_information_backward_kernel(
p_buf
[
s
][
t
+
1
]
*
py_buf
[
s
][
t
]);
}
}
}
}
// Write out p_grad, px_grad and py_grad.
for
(
int
i
=
threadIdx
.
x
;
i
<
BLOCK_SIZE
*
BLOCK_SIZE
;
i
+=
blockDim
.
x
)
{
...
...
torch_mutual_information/mutual_information_test.py
View file @
02a36166
...
...
@@ -12,13 +12,14 @@ def test_mutual_information_basic():
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
B
=
2
S
=
4
T
=
5
S
=
17
T
=
17
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
).
to
(
device
)
# log of an odds ratio
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
).
to
(
device
)
# log of an odds ratio
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
m
=
mutual_information_recursion
(
px
,
py
,
None
)
#m = mutual_information_recursion(px, py, boundary)
print
(
"m = "
,
m
)
print
(
"exp(m) = "
,
m
.
exp
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment