OpenDAS / FAST-RNNT · Commits

Commit 1ad556dc
Authored Jul 29, 2021 by Daniel Povey
Parent: 52ae49ee

    Fix some bugs..
Showing 4 changed files with 40 additions and 81 deletions:

    torch_mutual_information/__init__.py                        +1   -1
    torch_mutual_information/mutual_information.py              +28  -21
    torch_mutual_information/mutual_information_cuda_kernel.cu  +3   -3
    torch_mutual_information/mutual_information_test.py         +8   -56
torch_mutual_information/__init__.py

-from .mutual_information import mutual_information
+from .mutual_information import mutual_information_recursion
torch_mutual_information/mutual_information.py

 import os
 import torch
-from typing import Tuple
+from torch import Tensor
+from typing import Tuple, Optional
 from torch.utils.cpp_extension import load

 VERBOSE = False
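For context, the `load` import above is the hook for JIT-compiling the C++/CUDA extensions that the dispatchers below call into. A minimal sketch of that mechanism (the source path here is a hypothetical placeholder, not the project's actual file layout):

    from torch.utils.cpp_extension import load

    # JIT-compile and import a C++ extension at module load time; the real
    # module does this for its CPU and CUDA mutual-information kernels.
    torch_mutual_information_cpu = load(
        name='torch_mutual_information_cpu',
        sources=['torch_mutual_information/mutual_information_cpu.cpp'],
        verbose=False)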
@@ -44,18 +45,18 @@ except ImportError:

 def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                           boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-    if input.is_cuda:
+                                           boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+    if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            px, py, boundaries, p)
+            px, py, boundary, p)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            px, py, boundaries, p)
+            px, py, boundary, p)

 def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
-                                            boundaries: torch.Tensor, p: torch.Tensor,
+                                            boundary: torch.Tensor, p: torch.Tensor,
                                             ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if px.is_cuda:
         if torch_mutual_information_cuda is None:
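Two of the bug fixes are visible in this hunk: the forward dispatcher now tests px.is_cuda instead of the undefined name `input`, and the third argument is consistently called `boundary`. A self-contained sketch of the same device-dispatch pattern, with hypothetical stubs standing in for the compiled extensions:

    import torch

    class _CpuStub:  # hypothetical stand-in for the compiled CPU extension
        @staticmethod
        def mutual_information_cpu(px, *args):
            return px.sum()

    torch_mutual_information_cpu = _CpuStub()
    torch_mutual_information_cuda = None  # as if the CUDA build failed to load

    def _dispatch(px: torch.Tensor, *args) -> torch.Tensor:
        # Route on the device of a real argument (px), not on a stale name:
        if px.is_cuda:
            if torch_mutual_information_cuda is None:
                raise EnvironmentError('Failed to load native CUDA module')
            return torch_mutual_information_cuda.mutual_information_cuda(px, *args)
        return torch_mutual_information_cpu.mutual_information_cpu(px, *args)

    print(_dispatch(torch.zeros(2, 3)))  # takes the CPU path -> tensor(0.)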
@@ -64,7 +65,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
         if overwrite_ans_grad:
             ans_grad_copy = ans_grad.clone()
         ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
+            px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
         if overwrite_ans_grad:
             if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
                 print(f"Warning: possible excessive roundoff in mutual information backward "
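The clone-and-compare around overwrite_ans_grad is a roundoff guard: the CUDA backward may overwrite ans_grad in place, so the dispatcher clones it first and warns when the overwritten copy drifts too far from the original. A minimal sketch of the idea, with a small fake perturbation standing in for the real kernel:

    import torch

    def backward_with_roundoff_check(ans_grad: torch.Tensor) -> None:
        ans_grad_copy = ans_grad.clone()
        # Hypothetical stand-in for mutual_information_backward_cuda, which
        # may rewrite ans_grad_copy in place with some floating-point error:
        ans_grad_copy += 1.0e-07 * torch.randn_like(ans_grad_copy)
        if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
            print("Warning: possible excessive roundoff in mutual information backward")

    backward_with_roundoff_check(torch.ones(4))  # silent: the perturbation is tiny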
@@ -72,18 +73,20 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
         return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            px, py, boundaries, p, ans_grad))
+            px, py, boundary, p, ans_grad))


 class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
-                boundaries: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
+                boundary: Optional[torch.Tensor]) -> torch.Tensor:
         (B, S, T1) = px.shape
         T = T1 - 1;
         assert py.shape == (B, S + 1, T)
-        if boundaries is not None:
-            assert boundaries.shape == (B, 4)
+        if boundary is not None:
+            assert boundary.shape == (B, 4)
+        else:
+            boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)
         # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is the
@@ -101,20 +104,23 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
         p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

-        ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
+        ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
+
+        print(f"p = {p}, boundary = {boundary}")

         if px.requires_grad or py.requires_grad:
-            ctx.save_for_backward(px, py, boundaries, p)
+            ctx.save_for_backward(px, py, boundary, p)
         return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (px, py, boundaries, p) = ctx.saved_tensors
-        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
+        (px, py, boundary, p) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
         return (px_grad, py_grad, None)

-def mutual_information_recursion(input, px, py, boundaries=None):
+def mutual_information_recursion(px, py, boundary=None):
     """A recursion that is useful in computing mutual information between two sequences of
     real vectors, but may be useful more generally in sequence-to-sequence tasks where
     monotonic alignment between pairs of sequences is desired. The definitions of
@@ -154,7 +160,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
           is that for optimization purposes we assume the last axis (the t axis)
           has stride of 1; this is true if px and py are contiguous.
-      boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
+      boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
           [s_begin, t_begin, s_end, t_end]. If not supplied, the values
           [0, 0, S, T] will be assumed. These are the beginning and
           one-past-the-last positions in the x and y sequences
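For illustration only (values chosen arbitrarily), a boundary argument of the documented shape could be built like this; the second row restricts batch element 1 to a sub-range of its sequences:

    import torch

    B, S, T = 2, 4, 5
    # One [s_begin, t_begin, s_end, t_end] row per batch element; omitting
    # the argument is equivalent to [0, 0, S, T] for every row.
    boundary = torch.tensor([[0, 0, S, T],
                             [1, 2, 3, 5]], dtype=torch.int64)
    assert boundary.shape == (B, 4)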
@@ -182,8 +188,9 @@ def mutual_information_recursion(input, px, py, boundaries=None):
     assert py.shape == (B, S + 1, T)
     assert px.dtype == py.dtype
-    if boundaries is not None:
-        assert boundaries.dtype == torch.LongTensor
-        assert boundaries.shape == (B, 4)
+    (B, S, T) = px.shape
+    if boundary is not None:
+        assert boundary.dtype == torch.LongTensor
+        assert boundary.shape == (B, 4)

-    return MutualInformationRecursion.apply(px, py, boundaries)
+    return MutualInformationRecursionFunction.apply(px, py, boundary)
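After this commit the public entry point is mutual_information_recursion(px, py, boundary=None). A minimal usage sketch, consistent with the shapes asserted above and with the updated test below:

    import torch
    from torch_mutual_information import mutual_information_recursion

    B, S, T = 2, 4, 5
    px = torch.zeros(B, S, T + 1)   # log of an odds ratio
    py = torch.zeros(B, S + 1, T)   # log of an odds ratio
    m = mutual_information_recursion(px, py)  # boundary=None means [0, 0, S, T]
    print(m)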
torch_mutual_information/mutual_information_cuda_kernel.cu
@@ -519,7 +519,7 @@ void mutual_information_backward_kernel(
   // comments.  We'll focus, in the comments, on differences from the forward
   // pass.
   const int num_s_blocks = S / BLOCK_SIZE + 1,
-      num_t_blocks = T / BLOCK_SIZE + 1,
+      // num_t_blocks = T / BLOCK_SIZE + 1,
       num_blocks_this_iter = min(iter + 1, num_s_blocks);
@@ -668,7 +668,7 @@ void mutual_information_backward_kernel(
         s = s_in_block + s_block_begin,
         t = t_in_block + t_block_begin;
     p_buf[s_in_block][t_in_block] = (
-        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
+        s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
   } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
              static_cast<unsigned int>(block_T)) {
     // casting to unsigned before the comparison tests for both negative and
@@ -678,7 +678,7 @@ void mutual_information_backward_kernel(
         s = s_in_block + s_block_begin,
         t = t_in_block + t_block_begin;
     p_buf[s_in_block][t_in_block] = (
-        s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
+        s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
   }
   // The highest-numbered value in p_buf that we need (corresponding,
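Both kernel hunks fix the same indexing bug: p_grad is batched, so p_grad[s][t] indexes the (batch, s) axes rather than (s, t) within the current batch element. A small Python sketch of the equivalent mistake on a hypothetical batched tensor:

    import torch

    B, S, T = 8, 4, 5
    p_grad = torch.randn(B, S + 1, T + 1)  # hypothetical batched gradient
    b, s, t = 1, 2, 3
    wrong = p_grad[s][t]     # picks batch s, row t: a whole vector of wrong data
    right = p_grad[b][s][t]  # the intended scalar for batch element b
    print(wrong.shape, right.shape)  # torch.Size([6]) torch.Size([])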
torch_mutual_information/mutual_information_test.py
@@ -3,72 +3,24 @@
 import random
 import torch

-from torch_mutual_information import mutual_information
+from torch_mutual_information import mutual_information_recursion


 def test_mutual_information_basic():
     print("Running test_mutual_information_basic()")
     for dtype in [torch.float32, torch.float64]:
         B = 2
-        C = 4
-        T = 10
-        x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
-        x = x.reshape(1, 1, 10).repeat(B, C, 1)
+        S = 4
+        T = 5
+        px = torch.zeros(B, S, T + 1)  # log of an odds ratio
+        py = torch.zeros(B, S + 1, T)  # log of an odds ratio

-        K = 4
-        N = K * 2
-        params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
-        x.requires_grad = True
-        params.requires_grad = True
-        print("x = ", x)
-        print("params = ", params)
-        print("x.shape = ", x.shape)
+        m = mutual_information_recursion(px, py)

-        y = mutual_information(x, params, dim=1)
+        print("m = ", m)

-        if True:
-            # Check
-            x2 = x.reshape(B, C, 5, 2)
-            assert torch.allclose(mutual_information(x, params, dim=1),
-                                  mutual_information(x2, params, dim=1).reshape(x.shape))
-            x2 = x.reshape(B, 1, C, 10)
-            assert torch.allclose(mutual_information(x, params, dim=1),
-                                  mutual_information(x2, params, dim=2).reshape(x.shape))
-        print("y = ", y)
-        y.sum().backward()
-        if torch.cuda.is_available():
-            # test that the CUDA forward is the same as the CPU forward.
-            device = torch.device('cuda:0')
-            x2 = x.to(device).detach()
-            x2.requires_grad = True
-            params2 = params.to(device).detach()
-            params2.requires_grad = True
-            y2 = mutual_information(x2, params2, dim=1).to(torch.device('cpu'))
-            print("Checking CUDA is same")
-            if not torch.allclose(y, y2, atol=1.0e-06):
-                print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2 - y}")
-                assert(0);
-            y2.sum().backward()
-            if not torch.allclose(x.grad, x2.grad.to('cpu'), atol=1.0e-06):
-                print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu') - x.grad}")
-                assert(0);
-            if not torch.allclose(params.grad, params2.grad.to('cpu'), atol=1.0e-06):
-                print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu') - params.grad}")
-                assert(0);
-        print("x.grad = ", x.grad)
-        print("params.grad = ", params.grad)
-        # Just eyeballing the above to make sure it looks reasonable.


 def test_mutual_information_deriv():
     """ Tests derivatives in randomized way """