OpenDAS / FAST-RNNT

Commit 85c97136, authored Jul 30, 2021 by Daniel Povey
Fix various bugs...
parent 02a36166

Showing 2 changed files with 101 additions and 84 deletions
torch_mutual_information/mutual_information_cuda_kernel.cu  (+96, -79)
torch_mutual_information/mutual_information_test.py  (+5, -5)
torch_mutual_information/mutual_information_cuda_kernel.cu
...
...
@@ -174,11 +174,13 @@ void mutual_information_kernel(
    t_block_begin = (iter - block) * BLOCK_SIZE;
    bool is_origin_block = (s_block_begin * t_block_begin == 0);
    assert(b < B && b >= 0 && t_block_begin >= 0);  // TODO: remove.
    __syncthreads();
    if (boundary.size(0) != 0 && threadIdx.x < 4)
      boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
    __syncthreads();
    int s_begin = boundary_buf[0],
        t_begin = boundary_buf[1],
        s_end = boundary_buf[2],
...
...
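For context on the boundary_buf load above: each row of `boundary` holds the per-sequence region [s_begin, t_begin, s_end, t_end], and an empty boundary tensor appears to mean that the whole [0, S] x [0, T] region is used (the kernel only reads it when boundary.size(0) != 0, and the test below passes None). A minimal host-side sketch of the same lookup; `get_bounds` is a hypothetical helper, not part of this repository:

    import torch

    def get_bounds(boundary: torch.Tensor, b: int, S: int, T: int):
        # Mirrors what the kernel reads into boundary_buf[0..3]:
        # s_begin, t_begin, s_end, t_end for batch element b.
        if boundary.numel() == 0:
            return 0, 0, S, T
        s_begin, t_begin, s_end, t_end = (int(x) for x in boundary[b])
        return s_begin, t_begin, s_end, t_end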
@@ -225,34 +227,36 @@ void mutual_information_kernel(
// needed). This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
// 1][t + t_block_begin - 1] - normalizer.
    if (threadIdx.x < 64) {  // 64 == warp size.  First half of threads...
      if (threadIdx.x <= BLOCK_SIZE) {
        // s_in_p_buf are simply the indexes into p_buf
        int s_in_p_buf = threadIdx.x,
            t_in_p_buf = 0,
            s = s_in_p_buf + s_block_begin - 1,
            t = t_in_p_buf + t_block_begin - 1;
        scalar_t this_p = -INFINITY;
        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
          this_p = p[b][s][t];
        p_buf[s_in_p_buf][t_in_p_buf] = this_p;
      if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf are simply the indexes into p_buf
        int s_in_p_buf = threadIdx.x,
            t_in_p_buf = 0,
            s = s_in_p_buf + s_block_begin - 1,
            t = t_in_p_buf + t_block_begin - 1;
        scalar_t this_p = -INFINITY;
        if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
            static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
          this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
            (float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
}
        p_buf[s_in_p_buf][t_in_p_buf] = this_p;
      } else {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
        if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
            static_cast<unsigned int>(BLOCK_SIZE)) {
          int s_in_p_buf = 0,
              t_in_p_buf = threadIdx.x - 64,
              t_in_p_buf = (int)threadIdx.x - 64,
              s = s_in_p_buf + s_block_begin - 1,
              t = t_in_p_buf + t_block_begin - 1;
// The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values.
          scalar_t this_p = -INFINITY;
          if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
              static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
              static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
            this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
}
          p_buf[s_in_p_buf][t_in_p_buf] = this_p;
        }
      }
...
...
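A rough NumPy sketch of what the two half-warps above accomplish (illustrative only; the helper name and host-side shapes are assumptions): the first row and first column of the shared p_buf tile are filled with log-space values of p from previously computed blocks, and -infinity wherever (s, t) falls outside [0, s_end] x [0, t_end]. The kernel later normalizes and exponentiates these entries in place.

    import numpy as np

    def load_context(p_b, s_block_begin, t_block_begin, s_end, t_end, block_size):
        # p_b: the (S + 1, T + 1) matrix p[b] for one batch element, in log space.
        p_buf = np.full((block_size + 1, block_size + 1), -np.inf)
        for s_in_p_buf in range(block_size + 1):   # first leg: t_in_p_buf == 0
            s = s_in_p_buf + s_block_begin - 1
            t = t_block_begin - 1
            if 0 <= s <= s_end and 0 <= t <= t_end:
                p_buf[s_in_p_buf, 0] = p_b[s, t]
        for t_in_p_buf in range(block_size + 1):   # second leg: s_in_p_buf == 0
            s = s_block_begin - 1
            t = t_in_p_buf + t_block_begin - 1
            if 0 <= s <= s_end and 0 <= t <= t_end:
                p_buf[0, t_in_p_buf] = p_b[s, t]
        return p_buf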
@@ -264,7 +268,9 @@ void mutual_information_kernel(
// zero, and then exponentiate. We'll do everything in non-log space, for
// speed, and later take a log before we write out the data.
    scalar_t normalizer = (is_origin_block ? 0.0 :
                           max(px_buf[0][1], px_buf[1][0]));
                           max(p_buf[0][1], p_buf[1][0]));
    __syncthreads();
// Normalize and exponentiate the edge elements of p_buf, i.e. the elements
// where at one index is 0. The [0][0] element is special; we write 0.0,
...
...
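The normalizer above is taken from the already-known context values so that the subsequent linear-space (non-log) recursion stays in a safe floating-point range; a log is taken only when results are written back. The same trick in plain Python, as a sketch rather than the project's code:

    import math

    def log_add(a: float, b: float) -> float:
        # LogAdd computed the way the kernel does it: subtract a normalizer,
        # work in linear space, then take a single log at the end.
        normalizer = max(a, b)
        if normalizer == -math.inf:
            return -math.inf
        return normalizer + math.log(math.exp(a - normalizer) + math.exp(b - normalizer))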
@@ -286,8 +292,12 @@ void mutual_information_kernel(
      p_buf[0][t] = exp(p_buf[0][t] - normalizer);
    }
    __syncthreads();
// from here to the next __syncthreads(), only the 1st warp should be active
// so we shouldn't need to synchronize. (implicit within-warp
// synchronization).
    if (threadIdx.x == 0 && is_origin_block) {
    if (threadIdx.x == 0) {
// This if-statement is an optimization and modification of the loop below
// for the value i == 0, i.e. inner-iteration == 0. The modification is
// to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
...
...
@@ -309,55 +319,58 @@ void mutual_information_kernel(
}
    int s = threadIdx.x;
    if (s < block_S) {
      for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
        int s = threadIdx.x, t = i - s;
        if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
    for (int i = 1; i < block_S + block_T - 1; ++i) {
      __syncwarp();
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
      int t = i - s;
      if (s < block_S &&
          static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1
        p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
                              p_buf[s + 1][t] * py_buf[s][t];
        p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] +
                              p_buf[s + 1][t] * py_buf[s][t];
/*printf("threadIdx.x = %d, i = %d, s = %d, t = %d, p_buf[s+1][t+1] = %f, p_buf[s][t+1] = %f, "
"px_buf[s][t] = %f, p_buf[s + 1][t] = %f, py_buf[s][t] = %f\n",
(int)threadIdx.x, i, s, t, (float)p_buf[s+1][t+1], (float)p_buf[s][t+1],
(float)px_buf[s][t], (float)p_buf[s+1][t], (float)py_buf[s][t]);*/
#else
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
        p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] +
                     p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
        p_buf[s + 1][t + 1] = p_buf_s1_t;
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
        p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] +
                     p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
        p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
    }
    __syncthreads();
  }
  __syncthreads();
// Write out the data to p; check that nothing has gone out of numerical
// range, and write 'panic' flag if it has.
...
...
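For reference, the recursion in the comment above written out directly in log space, as a slow illustrative Python sketch (not the project's code), using the shapes the repo uses elsewhere: px of shape (B, S, T + 1) and py of shape (B, S + 1, T). It also assumes the default full boundary, so the per-batch result is p[:, S, T]; the kernel computes the same quantity block by block in exp space.

    import torch

    def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
        # p[b][s][t] = LogAdd(p[b][s-1][t] + px[b][s-1][t],
        #                     p[b][s][t-1] + py[b][s][t-1]),  with p[b][0][0] = 0.
        B, S, T1 = px.shape
        T = T1 - 1
        neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype)
        p = torch.full((B, S + 1, T + 1), float('-inf'), dtype=px.dtype)
        p[:, 0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
                term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
                p[:, s, t] = torch.logaddexp(term_x, term_y)
        return p[:, S, T]   # one value per batch element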
@@ -369,9 +382,11 @@ void mutual_information_kernel(
    if (s_in_block < block_S && t_in_block < block_T) {
      float this_p = p_buf[s_in_block + 1][t_in_block + 1];
      p[b][s][t] = normalizer + log(this_p);
// If this_p is infinity, NaN or zero...
      if (this_p - this_p != 0 || this_p == 0)
// If this_p is infinity or NaN..
      if (this_p - this_p != 0) {
        printf("[panic] threadIdx.x = %d, this_p = %f\n", (int)threadIdx.x, (float)this_p);
        p_buf[0][0] = 1.0;  // This is a "panic" flag.
}
}
}
...
...
@@ -391,6 +406,8 @@ void mutual_information_kernel(
}
  if (p_buf[0][0] != 0.0) {
    if (threadIdx.x == 0)
      printf("Panic flag set, value = %f\n", (float)p_buf[0][0]);  // TEMP
// The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
...
...
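The "panic" path re-runs the block with exact log-add whenever the fast linear-space pass produced something unusable; the `this_p - this_p != 0` test in the previous hunk is the standard trick for catching both NaN and infinities in one comparison. A one-line Python equivalent, purely for illustration:

    def needs_log_space_fallback(this_p: float) -> bool:
        # x - x is NaN for NaN and for +/-inf, and NaN != 0 compares True,
        # so this is true exactly for NaN and infinite values.
        return this_p - this_p != 0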
@@ -398,9 +415,10 @@ void mutual_information_kernel(
      int s_in_block = threadIdx.x;
      if (s_in_block < block_S) {
        for (int i = 0; i < block_S + block_T - 1; ++i) {
          __syncwarp();
          int t_in_block = i - s_in_block;
          if (s_in_block < block_S &&
              static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
          if (static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
            int s = s_in_block + s_block_begin,
                t = t_in_block + t_block_begin;
            float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
...
...
@@ -717,18 +735,17 @@ void mutual_information_backward_kernel(
  {
    int s = threadIdx.x;
    if (s < block_S) {
      for (int i = first_iter; i >= 0; --i) {
        int t = i - s;
        if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
          p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
                         p_buf[s][t + 1] * py_buf[s][t]);
        }
    for (int i = first_iter; i >= 0; --i) {
      __syncwarp();
      int t = i - s;
      if (t >= 0 && s < block_S) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
        p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
                       p_buf[s][t + 1] * py_buf[s][t]);
      }
    }
  }
...
...
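A small Python sketch of the reverse recursion in the backward kernel's comment (eq. 6), processed along anti-diagonals the way the loop above does. The function and argument names are illustrative assumptions, not the project's API; in the kernel these quantities live in the shared p_buf, px_buf and py_buf tiles.

    import numpy as np

    def backward_block(p_grad, xderiv, yderiv, block_S, block_T, first_iter):
        # p_grad: (block_S + 1, block_T + 1), with its last row and column already
        # holding gradient context from later blocks; xderiv/yderiv: (block_S, block_T).
        for i in range(first_iter, -1, -1):      # anti-diagonals, in reverse order
            for s in range(block_S):
                t = i - s
                if 0 <= t < block_T:
                    p_grad[s, t] = (p_grad[s + 1, t] * xderiv[s, t] +
                                    p_grad[s, t + 1] * yderiv[s, t])
        return p_grad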
torch_mutual_information/mutual_information_test.py
...
...
@@ -12,16 +12,16 @@ def test_mutual_information_basic():
    for device in [torch.device('cpu'), torch.device('cuda:0')]:
        print("dtype = ", dtype, ", device = ", device)
        B = 2
        S = 17
        T = 17
        S = 33
        T = 33
        boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
        px = torch.zeros(B, S, T + 1).to(device)  # log of an odds ratio
        py = torch.zeros(B, S + 1, T).to(device)  # log of an odds ratio
        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio

        m = mutual_information_recursion(px, py, None)
        #m = mutual_information_recursion(px, py, boundary)

        print("m = ", m)
        print("m = ", m, ", size = ", m.shape)
        print("exp(m) = ", m.exp())
...
...
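For completeness, here is how the updated test exercises mutual_information_recursion, reassembled as standalone Python. The import path is an assumption; in the test file, dtype comes from an enclosing loop over float32/float64 and device from the loop shown above.

    import torch
    from torch_mutual_information import mutual_information_recursion  # assumed import path

    B, S, T = 2, 33, 33
    dtype = torch.float32
    device = torch.device('cpu')

    boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
    px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
    py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio

    m = mutual_information_recursion(px, py, None)  # or pass `boundary` instead of None
    print("m = ", m, ", size = ", m.shape)
    print("exp(m) = ", m.exp())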