OpenDAS / FAST-RNNT · Commits

Commit 1ad556dc
authored Jul 29, 2021 by Daniel Povey
parent 52ae49ee

Fix some bugs..
Showing 4 changed files with 40 additions and 81 deletions (+40 -81)
torch_mutual_information/__init__.py                         +1  -1
torch_mutual_information/mutual_information.py               +28 -21
torch_mutual_information/mutual_information_cuda_kernel.cu   +3  -3
torch_mutual_information/mutual_information_test.py          +8  -56
torch_mutual_information/__init__.py

-from .mutual_information import mutual_information
+from .mutual_information import mutual_information_recursion
torch_mutual_information/mutual_information.py

 import os
 import torch
-from typing import Tuple
+from torch import Tensor
+from typing import Tuple, Optional
 from torch.utils.cpp_extension import load

 VERBOSE = False
...
@@ -44,18 +45,18 @@ except ImportError:
 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                           boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-    if input.is_cuda:
+                                           boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+    if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, p)
+            px, py, boundary, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, p)
+            px, py, boundary, p)

 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                            boundaries: torch.Tensor, p: torch.Tensor,
+                                            boundary: torch.Tensor, p: torch.Tensor,
                                             ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
...
@@ -64,7 +65,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
         if overwrite_ans_grad:
             ans_grad_copy = ans_grad.clone()
         ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+            px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
         if overwrite_ans_grad:
             if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
                 print(f"Warning: possible excsssive roundoff in mutual information backward "
...
@@ -72,18 +73,20 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
         return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, p, ans_grad))
+            px, py, boundary, p, ans_grad))

 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
         (B, S, T1) = px.shape
         T = T1 - 1;
         assert py.shape == (B, S + 1, T)
-        if boundaries is not None:
-            assert boundaries.shape == (B, 4)
+        if boundary is not None:
+            assert boundary.shape == (B, 4)
+        else:
+            boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)

         # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is the
...
@@ -101,20 +104,23 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
         p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

-        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
+        ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
+
+        print(f"p = {p}, boundary = {boundary}")

         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, p)
+            ctx.save_for_backward(px, py, boundary, p)
         return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, p) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
+        (px, py, boundary, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
         return (px_grad, py_grad, None)

-def mutual_information_recursion(input, px, py, boundaries=None):
+def mutual_information_recursion(px, py, boundary=None):
     """A recursion that is useful in computing mutual information between two sequences of
     real vectors, but may be useful more generally in sequence-to-sequence tasks where
     monotonic alignment between pairs of sequences is desired. The definitions of
...
@@ -154,7 +160,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
           is that for optimization purposes we assume the last axis (the t axis)
           has stride of 1; this is true if px and py are contiguous.
-      boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
+      boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
            [s_begin, t_begin, s_end, t_end]. If not supplied, the values
            [0, 0, S, T] will be assumed. These are the beginning and
            one-past-the-last positions in the x and y sequences
...
@@ -182,8 +188,9 @@ def mutual_information_recursion(input, px, py, boundaries=None):
     assert py.shape == (B, S + 1, T)
     assert px.dtype == py.dtype
     (B, S, T) = px.shape
-    if boundaries is not None:
-        assert boundaries.dtype == torch.LongTensor
-        assert boundaries.shape == (B, 4)
-    return MutualInformationRecursion.apply(px, py, boundaries)
+    if boundary is not None:
+        assert boundary.dtype == torch.LongTensor
+        assert boundary.shape == (B, 4)
+    return MutualInformationRecursionFunction.apply(px, py, boundary)
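For orientation, a minimal sketch (not part of the commit) of how the renamed entry point might be called after these changes, following the shapes asserted in mutual_information.py and mirrored by the updated test; the random values, the requires_grad flags and the backward() call are illustrative assumptions.

# Illustrative sketch only: exercising mutual_information_recursion(px, py, boundary=None).
import torch
from torch_mutual_information import mutual_information_recursion

B, S, T = 2, 4, 5
px = torch.randn(B, S, T + 1, requires_grad=True)   # log of an odds ratio, shape (B, S, T+1)
py = torch.randn(B, S + 1, T, requires_grad=True)   # shape (B, S+1, T)

# boundary=None means [0, 0, S, T] is assumed for every batch element; an explicit
# boundary would be a LongTensor of shape (B, 4) holding [s_begin, t_begin, s_end, t_end].
m = mutual_information_recursion(px, py)
m.sum().backward()   # gradients reach px and py via _mutual_information_backward_dispatcher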
torch_mutual_information/mutual_information_cuda_kernel.cu

...
@@ -519,7 +519,7 @@ void mutual_information_backward_kernel(
   // comments. We'll focus, in the comments, on differences from the forward
   // pass.
   const int num_s_blocks = S / BLOCK_SIZE + 1,
-      num_t_blocks = T / BLOCK_SIZE + 1,
+      //  num_t_blocks = T / BLOCK_SIZE + 1,
       num_blocks_this_iter = min(iter + 1, num_s_blocks);
...
@@ -668,7 +668,7 @@ void mutual_information_backward_kernel(
         s = s_in_block + s_block_begin,
         t = t_in_block + t_block_begin;
     p_buf[s_in_block][t_in_block] = (
-        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
+        s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
   } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
              static_cast<unsigned int>(block_T)) {
     // casting to unsigned before the comparison tests for both negative and
...
@@ -678,7 +678,7 @@ void mutual_information_backward_kernel(
         s = s_in_block + s_block_begin,
         t = t_in_block + t_block_begin;
     p_buf[s_in_block][t_in_block] = (
-        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
+        s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
   }
 }
 // The highest-numbered value in p_buf that we need (corresponding,
...
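Both kernel edits add the missing batch index b when reading p_grad. As a rough Python-level analogue of the corrected load (an illustration only, assuming p_grad shares the (B, S + 1, T + 1) shape of p in mutual_information.py):

# Rough analogue of the fixed p_buf load; p_grad shape is an assumption.
import torch

B, S, T = 2, 4, 5
p_grad = torch.randn(B, S + 1, T + 1)

def load_p_buf_entry(b, s, t, s_end, t_end):
    # Mirrors: p_buf[s_in_block][t_in_block] = (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
    # Before this commit the batch index b was missing from the p_grad lookup.
    return p_grad[b][s][t].item() if (s <= s_end and t <= t_end) else 0.0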
torch_mutual_information/mutual_information_test.py

...
@@ -3,71 +3,23 @@
 import random
 import torch
-from torch_mutual_information import mutual_information
+from torch_mutual_information import mutual_information_recursion

 def test_mutual_information_basic():
+    print("Running test_mutual_information_basic()")
     for dtype in [torch.float32, torch.float64]:
         B = 2
-        C = 4
-        T = 10
-        x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
-        x = x.reshape(1, 1, 10).repeat(B, C, 1)
-        K = 4
-        N = K * 2
-        params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
-        x.requires_grad = True
-        params.requires_grad = True
-        print("x = ", x)
-        print("params = ", params)
-        print("x.shape = ", x.shape)
-        y = mutual_information(x, params, dim = 1)
-        if True:
-            # Check
-            x2 = x.reshape(B, C, 5, 2)
-            assert torch.allclose(mutual_information(x, params, dim = 1),
-                                  mutual_information(x2, params, dim = 1).reshape(x.shape))
-            x2 = x.reshape(B, 1, C, 10)
-            assert torch.allclose(mutual_information(x, params, dim = 1),
-                                  mutual_information(x2, params, dim = 2).reshape(x.shape))
-        print("y = ", y)
-        y.sum().backward()
-        if torch.cuda.is_available():
-            # test that the CUDA forward is the same as the CPU forward.
-            device = torch.device('cuda:0')
-            x2 = x.to(device).detach()
-            x2.requires_grad = True
-            params2 = params.to(device).detach()
-            params2.requires_grad = True
-            y2 = mutual_information(x2, params2, dim = 1).to(torch.device('cpu'))
-            print("Checking CUDA is same")
-            if not torch.allclose(y, y2, atol = 1.0e-06):
-                print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
-                assert(0);
-            y2.sum().backward()
-            if not torch.allclose(x.grad, x2.grad.to('cpu'), atol = 1.0e-06):
-                print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu')-x.grad}")
-                assert(0);
-            if not torch.allclose(params.grad, params2.grad.to('cpu'), atol = 1.0e-06):
-                print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu')-params.grad}")
-                assert(0);
-        print("x.grad = ", x.grad)
-        print("params.grad = ", params.grad)
-        # Just eyeballing the above tgo make sure it looks reasonable.
+        S = 4
+        T = 5
+        px = torch.zeros(B, S, T + 1)  # log of an odds ratio
+        py = torch.zeros(B, S + 1, T)  # log of an odds ratio
+        m = mutual_information_recursion(px, py)
+        print("m = ", m)

 def test_mutual_information_deriv():
...