Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
02a36166
"tests/unittest/vscode:/vscode.git/clone" did not exist on "843d13829b204586fbab59f1cf4055a290d56dcc"
Commit
02a36166
authored
Jul 29, 2021
by
Daniel Povey
Browse files
Fix more bugs, add some debug statement.
parent
9ac065f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
80 deletions
+100
-80
torch_mutual_information/mutual_information_cuda_kernel.cu
torch_mutual_information/mutual_information_cuda_kernel.cu
+96
-77
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+4
-3
No files found.
torch_mutual_information/mutual_information_cuda_kernel.cu
View file @
02a36166
...
@@ -168,11 +168,14 @@ void mutual_information_kernel(
...
@@ -168,11 +168,14 @@ void mutual_information_kernel(
b
=
batch_block_iter
%
B
;
// b is the index into the batch
b
=
batch_block_iter
%
B
;
// b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
// block < num_blocks_this_iter, so iter - block >= 0.
int
s_block_begin
=
block
*
BLOCK_SIZE
,
int
s_block_begin
=
block
*
BLOCK_SIZE
,
t_block_begin
=
(
iter
-
block
)
*
BLOCK_SIZE
;
t_block_begin
=
(
iter
-
block
)
*
BLOCK_SIZE
;
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
bool
is_origin_block
=
(
s_block_begin
*
t_block_begin
==
0
);
assert
(
b
<
B
&&
b
>=
0
&&
t_block_begin
>=
0
);
// TODO: remove.
if
(
boundary
.
size
(
0
)
!=
0
&&
threadIdx
.
x
<
4
)
if
(
boundary
.
size
(
0
)
!=
0
&&
threadIdx
.
x
<
4
)
boundary_buf
[
threadIdx
.
x
]
=
boundary
[
b
][
threadIdx
.
x
];
boundary_buf
[
threadIdx
.
x
]
=
boundary
[
b
][
threadIdx
.
x
];
__syncthreads
();
__syncthreads
();
...
@@ -233,10 +236,13 @@ void mutual_information_kernel(
...
@@ -233,10 +236,13 @@ void mutual_information_kernel(
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
this_p
=
p
[
b
][
s
][
t
];
this_p
=
p
[
b
][
s
][
t
];
p_buf
[
threadIdx
.
x
][
0
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
}
}
else
{
// Another warp handles the other leg
}
else
{
if
(
int
(
threadIdx
.
x
)
-
64
<=
BLOCK_SIZE
)
{
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<=
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
int
s_in_p_buf
=
0
,
int
s_in_p_buf
=
0
,
t_in_p_buf
=
threadIdx
.
x
-
64
,
t_in_p_buf
=
threadIdx
.
x
-
64
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
s
=
s_in_p_buf
+
s_block_begin
-
1
,
...
@@ -247,7 +253,7 @@ void mutual_information_kernel(
...
@@ -247,7 +253,7 @@ void mutual_information_kernel(
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
if
(
static_cast
<
unsigned
int
>
(
s
)
<=
static_cast
<
unsigned
int
>
(
s_end
)
&&
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
static_cast
<
unsigned
int
>
(
t
)
<=
static_cast
<
unsigned
int
>
(
t_end
))
this_p
=
p
[
b
][
s
][
t
];
this_p
=
p
[
b
][
s
][
t
];
p_buf
[
threadIdx
.
x
][
0
]
=
this_p
;
p_buf
[
s_in_p_buf
][
t_in_p_buf
]
=
this_p
;
}
}
}
}
...
@@ -269,11 +275,15 @@ void mutual_information_kernel(
...
@@ -269,11 +275,15 @@ void mutual_information_kernel(
// p_buf[0][0] = 0.0; <-- for search purposes.
// p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a
// We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator.
// 'panic' indicator.
p_buf
[
threadIdx
.
x
][
0
]
=
(
threadIdx
.
x
==
0
?
0.0
:
int
s
=
threadIdx
.
x
;
exp
(
p_buf
[
threadIdx
.
x
][
0
]
-
normalizer
));
p_buf
[
s
][
0
]
=
(
s
==
0
?
0.0
:
}
else
if
(
int
(
threadIdx
.
x
)
-
64
<
BLOCK_SIZE
)
{
exp
(
p_buf
[
s
][
0
]
-
normalizer
));
}
else
if
(
static_cast
<
unsigned
int
>
(
int
(
threadIdx
.
x
)
-
64
)
<
static_cast
<
unsigned
int
>
(
BLOCK_SIZE
))
{
// if (threadIdx.x - 64) >= 0 && (threadIdx.x - 64) < BLOCK_SIZE.
int
t
=
(
int
)
threadIdx
.
x
-
64
+
1
;
// 0 < t <= BLOCK_SIZE
// this happens in a different warp so can be in parallel to the code above.
// this happens in a different warp so can be in parallel to the code above.
p_buf
[
0
][
t
hreadIdx
.
x
+
1
]
=
exp
(
p_buf
[
0
][
t
hreadIdx
.
x
+
1
]
-
normalizer
);
p_buf
[
0
][
t
]
=
exp
(
p_buf
[
0
][
t
]
-
normalizer
);
}
}
...
@@ -298,50 +308,53 @@ void mutual_information_kernel(
...
@@ -298,50 +308,53 @@ void mutual_information_kernel(
p_buf_s1_t
=
p_buf
[
s
+
1
][
threadIdx
.
x
==
0
?
1
:
0
];
p_buf_s1_t
=
p_buf
[
s
+
1
][
threadIdx
.
x
==
0
?
1
:
0
];
}
}
for
(
int
i
=
1
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
int
s
=
threadIdx
.
x
;
// i is the inner iteration, which corresponds to the (s + t) indexes of
if
(
s
<
block_S
)
{
// the elements within the block that we write. So i == 0 writes
for
(
int
i
=
1
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// i is the inner iteration, which corresponds to the (s + t) indexes of
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// the elements within the block that we write. So i == 0 writes
// and (2, 1); and so on. Note: not many threads participate in this
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// out a very meaningful way for more threads to do work, that looked like
// and (2, 1); and so on. Note: not many threads participate in this
// it would really speed things up.
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// out a very meaningful way for more threads to do work, that looked like
// but we do at least do the I/O in an efficient way and keep the
// it would really speed things up.
// inner loop simple and fast (e.g. no exp() or log()).
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
int
s
=
threadIdx
.
x
,
// but we do at least do the I/O in an efficient way and keep the
t
=
i
-
s
;
// inner loop simple and fast (e.g. no exp() or log()).
int
s
=
threadIdx
.
x
,
if
(
static_cast
<
unsigned
int
>
(
t
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
t
=
i
-
s
;
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
if
(
static_cast
<
unsigned
int
>
(
t
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
// the way these buffers relate to the tensors p, px and py, and
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// ignoring `normalizer`, code below can be interpreted as follows,
// row and column for context from previous blocks. Taking into account
// writing sbb for s_block_begin and tbb for t_block_begin:
// the way these buffers relate to the tensors p, px and py, and
//
// ignoring `normalizer`, code below can be interpreted as follows,
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// writing sbb for s_block_begin and tbb for t_block_begin:
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// where you can see that apart from the offsets of tbb and sbb, this is
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
// the same as the recursion defined for p in
//
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1
#if 1
p_buf
[
s
+
1
][
t
+
1
]
=
p_buf
[
s
][
t
+
1
]
*
px_buf
[
s
][
t
]
+
p_buf
[
s
+
1
][
t
]
*
py_buf
[
s
][
t
];
p_buf
[
s
+
1
][
t
+
1
]
=
p_buf
[
s
][
t
+
1
]
*
px_buf
[
s
][
t
]
+
p_buf
[
s
+
1
][
t
]
*
py_buf
[
s
][
t
];
#else
#else
// This is an optimization of the statement above (the other half of
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
// the need for a load from shared memory.
p_buf_s1_t
=
p_buf
[
s
][
t
+
1
]
*
px_buf
[
s
][
t
]
+
p_buf_s1_t
*
py_buf
[
s
][
t
];
p_buf_s1_t
=
p_buf
[
s
][
t
+
1
]
*
px_buf
[
s
][
t
]
+
p_buf_s1_t
*
py_buf
[
s
][
t
];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
// the 1st item accessed is for s == 0, t == 1.
p_buf
[
s
+
1
][
t
+
1
]
=
p_buf_s1_t
;
p_buf
[
s
+
1
][
t
+
1
]
=
p_buf_s1_t
;
#endif
#endif
// We don't need to do __syncthreads() in this loop because all the
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
// if NVidia changes some things, we might need to sync here).
}
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -382,22 +395,24 @@ void mutual_information_kernel(
...
@@ -382,22 +395,24 @@ void mutual_information_kernel(
// This time we won't use the buffers, we'll just load and save from main
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
// should help us quite a bit.
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
int
s_in_block
=
threadIdx
.
x
;
int
s_in_block
=
threadIdx
.
x
,
if
(
s_in_block
<
block_S
)
{
t_in_block
=
i
-
s_in_block
;
for
(
int
i
=
0
;
i
<
block_S
+
block_T
-
1
;
++
i
)
{
if
(
s_in_block
<
block_S
&&
int
t_in_block
=
i
-
s_in_block
;
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
if
(
s_in_block
<
block_S
&&
int
s
=
s_in_block
+
s_block_begin
,
static_cast
<
unsigned
int
>
(
t_in_block
)
<
static_cast
<
unsigned
int
>
(
block_T
))
{
t
=
t_in_block
+
t_block_begin
;
int
s
=
s_in_block
+
s_block_begin
,
float
p_s1
=
(
s
==
0
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
t
=
t_in_block
+
t_block_begin
;
this_px
=
(
s
==
0
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
float
p_s1
=
(
s
==
0
?
-
INFINITY
:
p
[
b
][
s
-
1
][
t
]),
p_t1
=
(
t
==
0
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
this_px
=
(
s
==
0
?
-
INFINITY
:
px
[
b
][
s
-
1
][
t
]),
this_py
=
(
t
==
0
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
p_t1
=
(
t
==
0
?
-
INFINITY
:
p
[
b
][
s
][
t
-
1
]),
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
this_py
=
(
t
==
0
?
-
INFINITY
:
py
[
b
][
s
][
t
-
1
]);
p_t1
+
this_py
);
float
this_p
=
LogAdd
(
p_s1
+
this_px
,
if
(
i
==
0
&&
is_origin_block
)
p_t1
+
this_py
);
this_p
=
0.0
;
if
(
i
==
0
&&
is_origin_block
)
p
[
b
][
s
][
t
]
=
this_p
;
this_p
=
0.0
;
p
[
b
][
s
][
t
]
=
this_p
;
}
}
}
}
}
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -700,17 +715,21 @@ void mutual_information_backward_kernel(
...
@@ -700,17 +715,21 @@ void mutual_information_backward_kernel(
--
first_iter
;
--
first_iter
;
}
}
for
(
int
i
=
first_iter
;
i
>=
0
;
--
i
)
{
{
int
s
=
i
,
int
s
=
threadIdx
.
x
;
t
=
i
-
threadIdx
.
x
;
if
(
s
<
block_S
)
{
if
(
t
>=
0
)
{
for
(
int
i
=
first_iter
;
i
>=
0
;
--
i
)
{
// The following statement is really operating on the gradients;
int
t
=
i
-
s
;
// it corresponds, with offsets of s_block_begin and t_block_begin
if
(
t
>=
0
)
{
// on the indexes, to (eq. 6) defined above, i.e.:
// The following statement is really operating on the gradients;
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// it corresponds, with offsets of s_block_begin and t_block_begin
// p_grad[b][s][t + 1] * yderiv[b][s][t]
// on the indexes, to (eq. 6) defined above, i.e.:
p_buf
[
s
][
t
]
=
(
p_buf
[
s
+
1
][
t
]
*
px_buf
[
s
][
t
]
+
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
p_buf
[
s
][
t
+
1
]
*
py_buf
[
s
][
t
]);
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf
[
s
][
t
]
=
(
p_buf
[
s
+
1
][
t
]
*
px_buf
[
s
][
t
]
+
p_buf
[
s
][
t
+
1
]
*
py_buf
[
s
][
t
]);
}
}
}
}
}
}
...
...
torch_mutual_information/mutual_information_test.py
View file @
02a36166
...
@@ -12,13 +12,14 @@ def test_mutual_information_basic():
...
@@ -12,13 +12,14 @@ def test_mutual_information_basic():
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
print
(
"dtype = "
,
dtype
,
", device = "
,
device
)
B
=
2
B
=
2
S
=
4
S
=
17
T
=
5
T
=
17
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
boundary
=
torch
.
tensor
([
0
,
0
,
S
,
T
],
dtype
=
torch
.
int64
).
unsqueeze
(
0
).
expand
(
B
,
4
).
to
(
device
)
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
).
to
(
device
)
# log of an odds ratio
px
=
torch
.
zeros
(
B
,
S
,
T
+
1
).
to
(
device
)
# log of an odds ratio
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
).
to
(
device
)
# log of an odds ratio
py
=
torch
.
zeros
(
B
,
S
+
1
,
T
).
to
(
device
)
# log of an odds ratio
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
m
=
mutual_information_recursion
(
px
,
py
,
None
)
#m = mutual_information_recursion(px, py, boundary)
print
(
"m = "
,
m
)
print
(
"m = "
,
m
)
print
(
"exp(m) = "
,
m
.
exp
())
print
(
"exp(m) = "
,
m
.
exp
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment