OpenDAS / FAST-RNNT · Commits

Commit e95d7864, authored Jul 26, 2021 by Daniel Povey

Drafts..

parent 621d5fbb
Showing 3 changed files with 220 additions and 41 deletions
torch_mutual_information/mutual_information.py  (+14, -6)
torch_mutual_information/mutual_information_cpu.cpp  (+72, -23)
torch_mutual_information/mutual_information_cuda_kernel.cu  (+134, -12)
torch_mutual_information/mutual_information.py  (view file @ e95d7864)
...
...
@@ -84,16 +84,19 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
# has t_begin > 0 or s_begin > 0, i.e. we really access q as
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
# We don't need q if we are not going to do backprop
        q = (torch.empty(B, S + T, device=px.device, dtype=px.dtype)
             if px.requires_grad or py.requires_grad else None)
        if px.requires_grad or py.requires_grad:
            q = torch.empty(B, S, T, device=px.device, dtype=px.dtype)
        else:
            # We don't need to store q if we are not going to do backprop, but we
            # do pass in a temporary with one real row, expanded to have "fake" rows,
            # which happens to be convenient for the CPU implementation.
            q = torch.empty(1, 1, T, device=px.device,
                            dtype=px.dtype).expand(B, S + T, T)
        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)

        if px.requires_grad or py.requires_grad:
            ctx.save_for_backward(px, py, boundaries, w)
            ctx.save_for_backward(px, py, boundaries, q)
    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
...
...
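A quick sketch of the "fake rows" trick in the new allocation above (B, S, T values below are made up): torch.Tensor.expand on size-1 dimensions returns a view with stride 0, so the (1, 1, T) temporary presents itself with the (B, S + T, T) shape the C++ side checks for q, without allocating that much memory. That is acceptable here because q's contents are only needed when backprop will run.

import torch

B, S, T = 2, 3, 5
q = torch.empty(1, 1, T).expand(B, S + T, T)  # a view, not a (B, S + T, T) allocation
print(q.shape)    # torch.Size([2, 8, 5])
print(q.stride()) # (0, 0, 1): the expanded dims have stride 0, so every "row" aliases the one real row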
@@ -109,7 +112,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
look at the formula computed by this function.
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T],
...
...
@@ -131,6 +134,11 @@ def mutual_information_recursion(input, px, py, boundaries=None):
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
Note: we don't require px and py to be contiguous, but the
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
...
...
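Editor's sketch of the contiguity note above: px and py need not be contiguous, but the T axis is assumed to have stride 1. The shapes here are made up; the point is only how to check, and if necessary restore, stride-1 on the last axis.

import torch

B, S, T = 2, 4, 6
px = torch.randn(B, S, T)
py = torch.randn(B, S, T)
assert px.stride(-1) == 1 and py.stride(-1) == 1   # stride-1 T axis, as assumed

# A transposed view violates the assumption; .contiguous() restores it.
bad = torch.randn(B, T, S).transpose(1, 2)          # shape (B, S, T) but stride(-1) == S
assert bad.stride(-1) != 1
good = bad.contiguous()
assert good.stride(-1) == 1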
torch_mutual_information/mutual_information_cpu.cpp  (view file @ e95d7864)
#include <math.h> // for log1p, log1pf
#include <torch/extension.h>
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
  double diff;

  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.

  if (diff >= kMinLogDiffDouble) {
    double res;
    res = x + log1p(exp(diff));
    return res;
  }

  return x;  // return the larger one.
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
  float diff;

  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.

  if (diff >= kMinLogDiffFloat) {
    float res;
    res = x + log1pf(expf(diff));
    return res;
  }

  return x;  // return the larger one.
}
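kMinLogDiffDouble and kMinLogDiffFloat are referenced above but not defined in this excerpt. A common convention (used, e.g., in Kaldi) is the log of the machine epsilon of each type, below which log1p(exp(diff)) is numerically zero and LogAdd can simply return the larger argument. The values below are an assumption, not taken from this repository; the helper is just a Python mirror of the C++ LogAdd for checking.

import math
import numpy as np

# Assumed definitions (not shown in this diff): log of machine epsilon.
kMinLogDiffFloat = math.log(np.finfo(np.float32).eps)    # roughly -15.94
kMinLogDiffDouble = math.log(np.finfo(np.float64).eps)   # roughly -36.04

def log_add(x, y):
    # Returns log(exp(x) + exp(y)), mirroring the C++ LogAdd above.
    lo, hi = min(x, y), max(x, y)
    diff = lo - hi                        # diff <= 0
    if diff >= kMinLogDiffDouble:
        return hi + math.log1p(math.exp(diff))
    return hi                             # the smaller term is negligible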
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
torch::Tensor mutual_information_cpu(torch::Tensor input, torch::Tensor params) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
torch::Tensor mutual_information_cpu(torch::Tensor px,
                                     torch::Tensor py,
                                     std::optional<torch::Tensor> optional_boundary,
                                     torch::Tensor q) {
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(q.dim() == 3, "q must be 3-dimensional.");
  const int B = input.size(0),
            C = input.size(1),
            T = input.size(2),
            N = params.size(1) - 1,
            K = N / 2;
  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
  torch::Tensor y_vals = torch::empty({C, N}, opts),
      output = torch::empty({B, C, T}, opts);
  const int B = px.size(0),
            S = px.size(1),
            T = px.size(2);
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_cpu_loop", ([&] {
        auto params_a = params.accessor<scalar_t, 2>(),
             y_vals_a = y_vals.accessor<scalar_t, 2>();
  TORCH_CHECK(q.size(0) == B && q.size(1) == S + T && q.size(2) == T);
  auto long_opts = torch::TensorOptions().dtype(torch::kInt64);

  bool has_boundary = (bool)optional_boundary;
  if (!has_boundary)
    optional_boundary = torch::empty({}, long_opts);
  AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
        auto px_a = px.accessor<scalar_t, 3>(),
             py_a = py.accessor<scalar_t, 3>();
        for (int c = 0; c < C; c++) {
          scalar_t sum_negative = 0.0,
                   sum_positive = 0.0,
...
...
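The check that q has size (B, S + T, T) above matches the access pattern described in mutual_information.py, where q is really indexed as q[b, s - s_begin + t - t_begin, t - t_begin]: entries are stored by anti-diagonal index s + t and by t rather than by (s, t) directly. A small sketch of that layout, assuming default boundaries (s_begin = t_begin = 0) and made-up sizes:

import torch

B, S, T = 1, 3, 4
q = torch.zeros(B, S + T, T)   # the [B][S + T][T] layout the C++ code checks for

def q_put(b, s, t, value):
    # the entry for logical position (s, t) lives at row s + t, column t
    q[b, s + t, t] = value

def q_get(b, s, t):
    return q[b, s + t, t]

q_put(0, 2, 3, 7.0)
assert q_get(0, 2, 3) == 7.0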
torch_mutual_information/mutual_information_cuda_kernel.cu  (view file @ e95d7864)
...
...
@@ -3,7 +3,6 @@
#include <cooperative_groups.h>
#define THREADS_PER_BLOCK 256
...
...
@@ -43,9 +42,11 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
/*
Forward of mutual_information. Each thread group handles a single channel (channel
c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
image within the batch).
Forward of mutual_information. Each thread block handles blocks of (x, y) shape
equal to (BLOCK_S_SIZE, BLOCK_T_SIZE), e.g. (4, 64). Thread blocks loop over such
blocks, but they might typically loop only once. We sequentially launch groups of
threads in such a way that thread-blocks within a group do not depend on each other.
Template args:
scalar_t: the floating-point type, e.g. float, double, maybe half.
...
...
@@ -88,17 +89,138 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
*/
extern __shared__ int extern_buf[];
template <typename scalar_t>
template <typename scalar_t,
          int BLOCK_S_SIZE,   // e.g. BLOCK_S_SIZE == 4; power of 2
          int BLOCK_T_SIZE>   // e.g. BLOCK_T_SIZE == 64; power of 2.
// BLOCK_T_SIZE * 4 must equal num_threads; and must be >= 128, so BLOCK_T_SIZE >= 32 is required.
// (Note: this 4 is unrelated to BLOCK_S_SIZE but can be viewed as 1<<2,
// where 2 is the loop unrolling factor).
__global__ void mutual_information_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3> input,   // B, C, T, i.e. batch, channels, time
    torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
    torch::PackedTensorAccessor32<scalar_t, 3> output,  // B, C, T
    int images_per_thread_block) {
    torch::PackedTensorAccessor32<scalar_t, 3> px,   // B, S, T, i.e. batch, x_seq_length, y_seq_length
    torch::PackedTensorAccessor32<scalar_t, 3> py,   // B, S, T, as above
    torch::PackedTensorAccessor32<scalar_t, 3> p,    // B, S, T, as above.  This is an output.
    torch::PackedTensorAccessor32<scalar_t, 2> boundary,  // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    int iter) {
// This kernel is sequentially called with 'iter' = 0, 1, 2 and so on, up to:
// (S+BLOCK_S_SIZE-1)/BLOCK_S_SIZE + (T+BLOCK_T_SIZE-1)/BLOCK_T_SIZE - 1
// so that each group depends on the previous group...
  const int block_dimx = BLOCK_T_SIZE * 4;  // known at compile time.
  assert(blockDim.x == block_dimx);

  const int B = px.size(0),
            S = px.size(1),
            T = py.size(2);
// num_s_blocks and num_t_blocks are the number of blocks we need to cover the
// array of size (S, T) with blocks of this size, in the s and t directions
// respectively.
  const int num_s_blocks = (S + BLOCK_S_SIZE - 1) / BLOCK_S_SIZE,
            num_t_blocks = (T + BLOCK_T_SIZE - 1) / BLOCK_T_SIZE;
// num_blocks_this_iter is an upper bound on the number of blocks that might
// be active on this iteration. We go from the bottom left of the image
// so that on iter == 0 we process only one block with block-index (0, 0)
// then on iter == 1 we process block-indexes (1, 0) and (0, 1); and then on iter==2
// we process (2, 0), (1, 1) and (0, 2); and so on. We also will never have more
// than `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but
// the numbering we use corresponds to s and not t, so if we hit the num_t_blocks limit,
// the lowest-numbered blocks on s would just not be active and we'll 'continue' below).
  int num_blocks_this_iter = min(iter + 1, num_s_blocks);
  __shared__ scalar_t px_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
                      py_buf[BLOCK_S_SIZE][BLOCK_T_SIZE],
                      p_buf[BLOCK_S_SIZE + 1][BLOCK_T_SIZE + 1];  // 1st row/col of p_buf
                                                                  // correspond to the previous
                                                                  // blocks, or an edge case.
  __shared__ scalar_t boundary_buf[4];
// batch_block_iter iterates over both batch elements (index b), and block
// indexes
  for (int batch_block_iter = blockIdx.x;
       batch_block_iter < B * num_blocks_this_iter;
       batch_block_iter += gridDim.x) {
    int b = batch_block_iter % B,
        block = batch_block_iter / B;
    int s_block_begin = block * BLOCK_S_SIZE,
        t_block_begin = (iter - block) * BLOCK_T_SIZE;
    bool is_origin_block = (s_block_begin * t_block_begin == 0);

    int s_end, t_end;
    // s_end and t_end are the end points (last-plus-one) of the entire sequence.
    if (boundary.size(0) == 0) {
      s_end = S;
      t_end = T;
    }
    else {
      if (threadIdx.x < 4)
        boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
      __syncthreads();
      int s_begin = boundary_buf[0],
          t_begin = boundary_buf[1];
      s_end = boundary_buf[2];
      t_end = boundary_buf[3];
      s_block_begin += s_begin;
      t_block_begin += t_begin;
    }
// block_S and block_T are the actual sizes of this block, up to
// (BLOCK_S_SIZE, BLOCK_T_SIZE) but possibly truncated if we
// are towards the end of the sequence.
    int block_S = min(BLOCK_S_SIZE, s_end - s_block_begin),
        block_T = min(BLOCK_T_SIZE, t_end - t_block_begin);
    if (block_S <= 0 || block_T <= 0)
      continue;
// Load px_buf and py_buf. We exponentiate; the assumption is that they
// won't overflow or underflow! If they overflow we'll detect it later!
    for (int i = threadIdx.x; i < BLOCK_S_SIZE * BLOCK_T_SIZE; i += block_dimx) {
      int t = i % BLOCK_T_SIZE,
          s = i / BLOCK_T_SIZE;
      if (s < block_S && t < block_T) {
        px_buf[s][t] = exp(px[b][s + s_block_begin][t + t_block_begin]);
        py_buf[s][t] = exp(py[b][s + s_block_begin][t + t_block_begin]);
      } else {
        // Not necessary?  We'll see
        px_buf[s][t] = 0.0;
        py_buf[s][t] = 0.0;
      }
    }
// Load the 1st row and column of p_buf (except element[0][0] is not needed).
    if (threadIdx.x < 64) {  // 64 == warp size...
      if (threadIdx.x <= BLOCK_S_SIZE) {
        // this s and t are offsets relative to the block start
        int s = threadIdx.x - 1,
            t = -1;
        if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
            static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
          p_buf[threadIdx.x][0] = p[b][s + s_block_begin][t + t_block_begin];
        else
          p_buf[threadIdx.x][0] = -infinity;
      }
    }
    else {
      if (threadIdx.x - 64 <= BLOCK_T_SIZE) {
        int i = threadIdx.x - 64,
            t = i - 1,
            s = -1;
        if (static_cast<unsigned int>(s + s_block_begin) < static_cast<unsigned int>(block_S) &&
            static_cast<unsigned int>(t + t_block_begin) < static_cast<unsigned int>(block_T))
          p_buf[0][i] = p[b][s + s_block_begin][t + t_block_begin];
        else {
          p_buf[0][i] = (is_origin_block && i == 1 ? 1.0 : -infinity);
        }
      }
    }
  const int B = input.size(0),
            C = input.size(1),
            T = input.size(2),
            N = params.size(1) - 1,
            K = N / 2;  // Note: N and K are powers of 2, with K >= 1.
...
...
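The kernel's comments say it is launched sequentially with iter = 0, 1, ..., up to (S + BLOCK_S_SIZE - 1) / BLOCK_S_SIZE + (T + BLOCK_T_SIZE - 1) / BLOCK_T_SIZE - 1, and that on each iteration the active tiles lie on one anti-diagonal, so tiles within a single launch never depend on each other. A hedged Python sketch of that schedule (block sizes and S, T below are made up):

BLOCK_S_SIZE, BLOCK_T_SIZE = 4, 64
S, T = 10, 200

num_s_blocks = (S + BLOCK_S_SIZE - 1) // BLOCK_S_SIZE
num_t_blocks = (T + BLOCK_T_SIZE - 1) // BLOCK_T_SIZE
num_iters = num_s_blocks + num_t_blocks - 1   # formula from the kernel comment

for it in range(num_iters):
    active = []
    for block in range(min(it + 1, num_s_blocks)):   # num_blocks_this_iter in the kernel
        s_block, t_block = block, it - block
        if t_block >= num_t_blocks:
            continue    # mirrors the kernel's 'continue' for out-of-range tiles
        active.append((s_block, t_block))
    print(it, active)   # tiles on anti-diagonal s_block + t_block == it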