OpenDAS / FAST-RNNT · Commits

Commit 9ebcf9d5
Authored Jul 30, 2021 by Daniel Povey
Parent: 85c97136

Fix more bugs..
Showing 3 changed files with 30 additions and 13 deletions (+30 -13):

    torch_mutual_information/mutual_information_cpu.cpp          +8  -8
    torch_mutual_information/mutual_information_cuda_kernel.cu   +3  -1
    torch_mutual_information/mutual_information_test.py          +19 -4
torch_mutual_information/mutual_information_cpu.cpp
@@ -192,11 +192,11 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       for (int s = s_end; s > s_begin; --s) {
         for (int t = t_end; t > t_begin; --t) {
-          // The s,t indexes correspond to
           // The statement we are backpropagating here is:
           // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
           //                       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
           // .. which obtains p_a[b][s][t - 1] from a register.
           scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
               // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
               // actually needed..
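For orientation: the comment block above spells out the forward recursion that mutual_information_backward_cpu differentiates. A minimal Python sketch of that recursion, assuming px of shape (B, S, T + 1), py of shape (B, S + 1, T) and boundary rows of [s_begin, t_begin, s_end, t_end] as in the test file below; an illustration, not code from this repo:

    import torch

    def mutual_information_reference(px, py, boundary):
        # Illustrative reference for the recursion described in the comments.
        # px: (B, S, T + 1); py: (B, S + 1, T);
        # boundary[b] = [s_begin, t_begin, s_end, t_end].
        B = px.shape[0]
        ans = torch.empty(B, dtype=px.dtype)
        for b in range(B):
            s_begin, t_begin, s_end, t_end = boundary[b].tolist()
            p = torch.full((s_end + 1, t_end + 1), float('-inf'), dtype=px.dtype)
            p[s_begin][t_begin] = 0.0                 # no backprop for this
            for s in range(s_begin + 1, s_end + 1):   # t == t_begin column
                p[s][t_begin] = p[s - 1][t_begin] + px[b][s - 1][t_begin]
            for t in range(t_begin + 1, t_end + 1):   # s == s_begin row
                p[s_begin][t] = p[s_begin][t - 1] + py[b][s_begin][t - 1]
            for s in range(s_begin + 1, s_end + 1):   # interior: the LogAdd statement
                for t in range(t_begin + 1, t_end + 1):
                    p[s][t] = torch.logaddexp(p[s - 1][t] + px[b][s - 1][t],
                                              p[s][t - 1] + py[b][s][t - 1])
            ans[b] = p[s_end][t_end]
        return ans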
@@ -212,19 +212,19 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
           p_grad_a[b][s][t - 1] += term2_grad;
         }
       }
-      for (int t = t_end; t >= t_begin; --t) {
+      for (int t = t_end; t > t_begin; --t) {
         // Backprop for:
         // p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
         scalar_t this_p_grad = p_grad_a[b][s_begin][t];
         p_grad_a[b][s_begin][t - 1] += this_p_grad;
-        py_grad_a[b][s_begin][t - 1] += this_p_grad;
+        py_grad_a[b][s_begin][t - 1] = this_p_grad;
       }
-      for (int s = s_end; s >= s_begin; --s) {
+      for (int s = s_end; s > s_begin; --s) {
         // Backprop for:
         // p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
-        scalar_t this_p_grad = p_grad_a[b][s][s_begin];
+        scalar_t this_p_grad = p_grad_a[b][s][t_begin];
-        p_a[b][s - 1][t_begin] += this_p_grad;
+        p_grad_a[b][s - 1][t_begin] += this_p_grad;
-        px_a[b][s - 1][t_begin] += this_p_grad;
+        px_grad_a[b][s - 1][t_begin] = this_p_grad;
       }
       // There is no backprop for:
       // p_a[b][s_begin][t_begin] = 0.0;
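This hunk fixes three things in the boundary backprop: the loop bounds change from >= to > so the t - 1 / s - 1 accesses never step below t_begin / s_begin; the typos p_grad_a[b][s][s_begin], p_a[...] and px_a[...] become p_grad_a[b][s][t_begin], p_grad_a[...] and px_grad_a[...]; and the writes to px_grad / py_grad change from += to =. The last change looks safe because each boundary entry receives exactly one contribution, and it matters if the grad buffers are created with torch::empty (as the CUDA file below does when has_boundary is false), since += would then accumulate into uninitialized memory. A Python mirror of the fixed logic, same shapes as the sketch above; again an illustration, not repo code:

    def backward_boundary(p_grad, px_grad, py_grad, b,
                          s_begin, t_begin, s_end, t_end):
        # Mirror of the fixed boundary-row/column backprop.
        for t in range(t_end, t_begin, -1):   # stops before t_begin: t - 1 >= t_begin
            g = p_grad[b][s_begin][t]
            p_grad[b][s_begin][t - 1] += g
            py_grad[b][s_begin][t - 1] = g    # single write, so '=' not '+='
        for s in range(s_end, s_begin, -1):   # stops before s_begin: s - 1 >= s_begin
            g = p_grad[b][s][t_begin]
            p_grad[b][s - 1][t_begin] += g
            px_grad[b][s - 1][t_begin] = g    # single write, so '=' not '+='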
@@ -232,7 +232,7 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       // of the sequence is equal to the grad at the end of the sequence.
       if (ans_grad_a[b] != 0.0) {
         float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
-        if (grad_ratio - 1.0 > 0.01) {
+        if (fabs(grad_ratio - 1.0) > 0.01) {
           printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                  (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
         }
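The fabs() fix makes this sanity check two-sided: previously only grad_ratio > 1.01 triggered the warning, so a back-propagated value that was too small passed silently. The same check in Python (a hypothetical helper, illustration only):

    def check_grad_consistency(p_start, ans_grad):
        # p_start: grad propagated back to (s_begin, t_begin);
        # it should equal ans_grad for this batch element.
        if ans_grad != 0.0:
            grad_ratio = p_start / ans_grad
            if abs(grad_ratio - 1.0) > 0.01:   # old code: grad_ratio - 1.0 > 0.01
                print("Warning: mutual_information backprop: expected these "
                      f"numbers to be the same: {p_start} vs. {ans_grad}")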
torch_mutual_information/mutual_information_cuda_kernel.cu
@@ -750,6 +750,8 @@ void mutual_information_backward_kernel(
         }
       }
+      __syncthreads();
       // Write out p_grad, px_grad and py_grad.
       for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
         int s_in_block = i / BLOCK_SIZE,
@@ -881,7 +883,7 @@ mutual_information_backward_cuda(torch::Tensor px,
   torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
       px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
                  torch::empty({B, S, T + 1}, opts)),
-      py_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
                  torch::empty({B, S + 1, T}, opts));
   // num_threads and num_blocks and BLOCK_SIZE can be tuned.
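Two fixes here. The added __syncthreads() puts a barrier between the block's shared-memory computation and the loop that writes p_grad, px_grad and py_grad out to global memory, so no thread reads shared values that another thread is still writing. The allocation fix makes the zeros branch for py_grad use py's shape {B, S + 1, T} rather than px's {B, S, T + 1}, agreeing with the empty branch on the next line. A Python sketch of the corrected shapes (illustration; the opts argument elided):

    import torch

    B, S, T = 2, 14, 14
    has_boundary = True
    p_grad = torch.empty(B, S + 1, T + 1)
    px_grad = (torch.zeros(B, S, T + 1) if has_boundary
               else torch.empty(B, S, T + 1))
    # Before the fix, the zeros branch below used (B, S, T + 1), px's shape:
    py_grad = (torch.zeros(B, S + 1, T) if has_boundary
               else torch.empty(B, S + 1, T))
    assert py_grad.shape == (B, S + 1, T)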
torch_mutual_information/mutual_information_test.py
@@ -9,20 +9,35 @@ from torch_mutual_information import mutual_information_recursion
 def test_mutual_information_basic():
     print("Running test_mutual_information_basic()")
     for dtype in [torch.float32, torch.float64]:
+        px_grads = []
+        py_grads = []
         for device in [torch.device('cpu'), torch.device('cuda:0')]:
             print("dtype = ", dtype, ", device = ", device)
             B = 2
-            S = 33
-            T = 33
+            S = 14
+            T = 14
             boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
             px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
             py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
+            px.requires_grad = True
+            py.requires_grad = True
-            m = mutual_information_recursion(px, py, None)
-            # m = mutual_information_recursion(px, py, boundary)
+            # m = mutual_information_recursion(px, py, None)
+            m = mutual_information_recursion(px, py, boundary)
             print("m = ", m, ", size = ", m.shape)
             print("exp(m) = ", m.exp())
+            (m.sum() * 3).backward()
+            print("px_grad = ", px.grad)
+            print("py_grad = ", py.grad)
+            px_grads.append(px.grad.to('cpu'))
+            py_grads.append(py.grad.to('cpu'))
+        if not torch.allclose(px_grads[0], px_grads[1]):
+            print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
+            assert 0
+        if not torch.allclose(py_grads[0], py_grads[1]):
+            print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
+            assert 0
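A note on expected output: since px and py are all zeros and boundary covers the full (0, 0) to (S, T) grid, every monotonic lattice path contributes exp(0) = 1 to the total, so exp(m) should equal the number of such paths, binom(S + T, S). A quick value to compare against the test's printout (my own derivation from the recursion sketched earlier, not part of the commit):

    import math

    # With px = py = 0, exp(m) should equal the number of monotonic paths
    # from (0, 0) to (S, T), i.e. binom(S + T, S).
    S, T = 14, 14
    print(math.log(math.comb(S + T, S)))  # expected value of each entry of m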