OpenDAS / FAST-RNNT

Commit 17b18990, authored Jul 30, 2021 by Daniel Povey
Parent: c39d5fe7

    More debugging progress, still not fully debugged/tested.

Showing 3 changed files, with 92 additions and 267 deletions:

    torch_mutual_information/mutual_information.py               +1   -1
    torch_mutual_information/mutual_information_cuda_kernel.cu   +1   -1
    torch_mutual_information/mutual_information_test.py          +90  -265
torch_mutual_information/mutual_information.py

@@ -69,7 +69,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
         if overwrite_ans_grad:
             if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
                 print(f"Warning: possible excessive roundoff in mutual information backward "
-                      "recursion: {ans_grad} vs. {ans_grad_copy}");
+                      f"recursion: {ans_grad} vs. {ans_grad_copy}");
         return ans
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
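The one-character change above is easy to miss: with implicit string-literal concatenation in Python, the f prefix applies only to the literal it is attached to, so in the old code the second fragment printed its braces verbatim instead of interpolating ans_grad and ans_grad_copy. A minimal sketch of the difference (illustration only, not code from this repo):

    a, b = 1.5, 2.5
    # Only the first literal is an f-string; the braces in the second are printed as-is.
    print(f"a = {a}, " "b = {b}")    # -> a = 1.5, b = {b}
    # With an f prefix on both fragments, both placeholders are interpolated.
    print(f"a = {a}, " f"b = {b}")   # -> a = 1.5, b = 2.5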
torch_mutual_information/mutual_information_cuda_kernel.cu

@@ -384,7 +384,7 @@ void mutual_information_kernel(
         p[b][s][t] = normalizer + log(this_p);
         // If this_p is infinity or NaN..
         if (this_p - this_p != 0) {
-          printf("[panic] threadIdx.x = %d, this_p = %f\n", (int)threadIdx.x, (float)this_p);
+          // printf("[panic] threadIdx.x = %d, this_p = %f\n", (int)threadIdx.x, (float)this_p);
           p_buf[0][0] = 1.0;  // This is a "panic" flag.
         }
       }
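The guard `this_p - this_p != 0` is a compact way to detect both NaN and infinite values without calling isnan()/isinf(): for any finite x the difference is exactly 0, while inf - inf and nan - nan both produce NaN, and NaN compares unequal to everything. A small PyTorch sketch of the same idiom (illustration only, not code from this repo):

    import torch

    def is_nan_or_inf(x: torch.Tensor) -> torch.Tensor:
        # True wherever x is NaN or +/-inf, mirroring the kernel's `this_p - this_p != 0` test.
        return (x - x) != 0

    t = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
    print(is_nan_or_inf(t))   # tensor([False,  True,  True,  True])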
torch_mutual_information/mutual_information_test.py

@@ -15,28 +15,52 @@ def test_mutual_information_basic():
                                 random.randint(1, 200))
         random_px = (random.random() < 0.1)
         random_py = (random.random() < 0.1)
+        random_boundary = (random.random() < 0.1)
+        big_px = (random.random() < 0.1)
+        big_py = (random.random() < 0.1)
-        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}")
+        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}, random_boundary={random_boundary}")
         for dtype in [torch.float32, torch.float64]:
             px_grads = []
             py_grads = []
             m_vals = []
             for device in [torch.device('cpu'), torch.device('cuda:0')]:
                 print("dtype = ", dtype, ", device = ", device)
-                B = 2
-                S = 14
-                T = 14
-                boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
+                if random_boundary:
+                    def get_boundary_row():
+                        s_begin = random.randint(0, S - 1)
+                        t_begin = random.randint(0, T - 1)
+                        s_end = random.randint(s_begin + 1, S)
+                        t_end = random.randint(t_begin + 1, T)
+                        return [s_begin, t_begin, s_end, t_end]
+                    boundary = torch.tensor([get_boundary_row() for _ in range(B)],
+                                            dtype=torch.int64, device=device)
+                else:
+                    # Use default boundary, but either specified directly or not.
+                    if random.random() < 0.5:
+                        boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
+                    else:
+                        boundary = None
                 if device == torch.device('cpu'):
                     if random_px:
                         px = torch.randn(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                     else:
                         px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
+                    # px and py get exponentiated, and then multiplied together up to
+                    # 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
+                    # could lead to overflow.
+                    if big_px:
+                        px += 15.0
                     if random_py:
                         py = torch.randn(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                     else:
                         py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
+                    if big_py:
+                        py += 15.0
                 else:
                     px = px.to(device).detach()
                     py = py.to(device).detach()
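The comment above about the value 15 can be checked directly: the recursion works with exponentiated px/py values, float32 overflows to inf once exp() is taken of anything much above 88, and float64 only stretches to about exp(709), so a run of terms each shifted by 15 multiplies up to well past those limits. A small numeric sketch (assumes plain exp/multiplication, not the kernel's actual arithmetic):

    import torch

    x32 = torch.tensor(15.0, dtype=torch.float32)
    print(torch.exp(x32))         # ~3.3e6: one exp(15) is harmless
    print(torch.exp(6 * x32))     # exp(90) -> inf; float32 gives up near exp(88.7)
    print(torch.exp(32 * x32))    # exp(480) -> inf in float32 (32 factors of exp(15))
    x64 = torch.tensor(15.0, dtype=torch.float64)
    print(torch.exp(32 * x64))    # ~3e208: float64 still holds exp(480), but overflows near exp(709.8)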
@@ -66,278 +90,79 @@ def test_mutual_information_basic():

Removed by this commit (the old tests, which called mutual_information() with (x, params, dim=1) or (input, pos_add, pos_mul) signatures):

def test_mutual_information_deriv():
    """ Tests derivatives in randomized way """
    for _ in range(10):
        for dtype in [torch.float32, torch.float64]:
            B = random.randrange(1, 10)
            C = random.randrange(1, 10)
            T = random.randrange(1, 20)
            x = torch.randn(B, C, T, dtype=dtype)
            K = 2 ** random.randrange(0, 4)
            N = K * 2
            params = torch.randn(C, N + 1, dtype=dtype)
            x.requires_grad = True
            params.requires_grad = True
            print(f"B,C,T,K = {B},{C},{T},{K}")
            y = mutual_information(x, params, dim=1)
            y_deriv = torch.randn_like(y)
            y.backward(gradient=y_deriv)
            if torch.cuda.is_available():
                # test that the CUDA forward is the same as the CPU forward.
                device = torch.device('cuda:0')
                x2, params2 = x.to(device).detach(), params.to(device).detach()
                x2.requires_grad = True
                params2.requires_grad = True
                y2 = mutual_information(x2, params2, dim=1)
                if N >= 4 and N <= 16:  # Currently backprop requires these conditions
                    y2.backward(gradient=y_deriv.to(device))
                    x2grad, params2grad = x2.grad.to('cpu'), params2.grad.to('cpu')
                y2 = y2.to('cpu')
                print("Checking CUDA is same")
                if not torch.allclose(y, y2, atol=1.0e-05):
                    print(f"Error: CPU output versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
                    assert(0)
                if N >= 4 and N <= 16:  # Currently backprop requires these conditions
                    if not torch.allclose(x.grad, x2grad, atol=1.0e-05):
                        print(f"Error: CPU x.grad versus CUDA not the same: {x.grad} vs. {x2grad}, diff = {x2grad - x.grad}, max-diff = {(x2grad - x.grad).abs().max()}")
                        assert(0)
                    if not torch.allclose(params.grad, params2grad, atol=1.0e-05):
                        print(f"Error: CPU params.grad versus CUDA not the same: {params.grad} vs. {params2grad}, "
                              f"diff = {params2grad - params.grad}, max-diff = {(params2grad - params.grad).abs().max()}")
                        assert(0)
            delta = 1.0e-04
            delta_x = torch.randn_like(x) * delta
            pred_change = (x.grad * delta_x).sum()
            y2 = mutual_information(x + delta_x, params, dim=1)
            observed_change = (y_deriv * (y2 - y)).sum()
            print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
            if not torch.allclose(pred_change, observed_change, rtol=5.0e-02, atol=3.0e-05):
                print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
                assert 0
            delta_params = torch.randn_like(params) * delta
            pred_change = (params.grad * delta_params).sum()
            observed_change = (y_deriv * (mutual_information(x, params + delta_params, dim=1) - y)).sum()
            print(f"for params: pred_change = {pred_change}, observed_change={observed_change}")
            assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)


def test_mutual_information_zeros():
    N = 1
    C = 2
    H = 3
    W = 4
    for device in [torch.device('cpu'), torch.device('cuda:0')]:
        if device == torch.device('cuda:0') and not torch.cuda.is_available():
            print("Warning: torch not available, not testing this part.")
            continue
        for dtype in [torch.float32, torch.float64]:
            print("device=", device, ", dtype=", dtype)
            input = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype)
            kH = 5
            kW = 5
            pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
            pos_mul = torch.ones(C, kH, kW, device=device, dtype=dtype)
            input.requires_grad = True
            pos_add.requires_grad = True
            pos_mul.requires_grad = True
            output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
            output = mutual_information(input, pos_add, pos_mul)
            assert torch.allclose(output, output_ref)
            output.sum().backward()
            print("input_grad=", input.grad)
            print("pos_add_grad=", pos_add.grad)
            print("pos_mul_grad=", pos_mul.grad)


def test_mutual_information_compare():
    N = 1
    C = 2
    H = 3
    W = 4
    if not torch.cuda.is_available():
        print("Warning: torch not available, not testing this part.")
        return
    for dtype in [torch.float32, torch.float64]:
        print("dtype=", dtype)
        input = torch.randn(N, 2 * C, H, W, dtype=dtype)
        device = torch.device('cuda:0')
        input_cuda = input.to(device).detach()
        kH = 5
        kW = 5
        pos_add = torch.randn(C, kH, kW, dtype=dtype)
        pos_mul = torch.randn(C, kH, kW, dtype=dtype)
        pos_add_cuda = pos_add.to(device).detach()
        pos_mul_cuda = pos_mul.to(device).detach()
        for x in [pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, input, input_cuda]:
            x.requires_grad = True
        output = mutual_information(input, pos_add, pos_mul)
        output_cuda = mutual_information(input_cuda, pos_add_cuda, pos_mul_cuda)
        print("output = ", output)
        print("output_cuda = ", output_cuda)
        output_grad = torch.randn(*output.shape, dtype=dtype)
        output.backward(gradient=output_grad)
        output_cuda.backward(gradient=output_grad.to(device))
        diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
        abs = output.abs().sum()
        print("Diff = ", diff, ", abs = ", abs)
        assert torch.allclose(output, output_cuda.to(torch.device('cpu')), atol=1.0e-05)
        for a, b, name in [(pos_add, pos_add_cuda, 'pos_add'),
                           (pos_mul, pos_mul_cuda, 'pos_mul'),
                           (input, input_cuda, 'input')]:
            grad = a.grad
            cuda_grad = b.grad.to(torch.device('cpu'))
            diff_abs = (grad - cuda_grad).abs().sum().item()
            sum_abs = (grad + cuda_grad).abs().sum().item()
            print(f"Comparing grad of {name}: diff={diff_abs}, sum={sum_abs}")
            if diff_abs > 1.0e-05 * sum_abs:
                print(f"Error: too much difference in grad of {name}.")
                print("grad = ", grad)
                print("cuda_grad = ", cuda_grad)
def test_mutual_information_rand_compare():
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2
        if not torch.cuda.is_available():
            print("Warning: torch not available, not testing this part.")
            return
        for dtype in [torch.float32, torch.float64]:
            print("dtype=", dtype)
            input = torch.randn(N, 2 * C, H, W, dtype=dtype)
            device = torch.device('cuda:0')
            input_cuda = input.to(device)
            kH = random.randint(1, 10)
            kW = random.randint(1, 10)
            if kH % 2 == 0:
                kH += 1
            if kW % 2 == 0:
                kW += 1
            pos_add = torch.randn(C, kH, kW, dtype=dtype)
            pos_mul = torch.randn(C, kH, kW, dtype=dtype)
            pos_add_cuda = pos_add.to(device)
            pos_mul_cuda = pos_mul.to(device)
            output = mutual_information(input, pos_add, pos_mul)
            output_cuda = mutual_information(input_cuda, pos_add_cuda, pos_mul_cuda)
            diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
            sum_abs = output.abs().sum()
            print("Diff = ", diff, ", abs = ", sum_abs)
            if (diff / sum_abs).item() > 0.001:
                print("output = ", output)
                print("output_cuda = ", output_cuda)
                assert 0, "outputs differ"


def test_mutual_information_rand_grad():
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2
        for device in [torch.device('cpu'), torch.device('cuda:0')]:
            if device == torch.device('cuda:0') and not torch.cuda.is_available():
                print("Warning: torch not available, not testing this part.")
                continue
            for dtype in [torch.float32, torch.float64]:
                print("dtype=", dtype, ", device=", device)
                input = torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
                kH = random.randint(1, 10)
                kW = random.randint(1, 10)
                if kH % 2 == 0:
                    kH += 1
                if kW % 2 == 0:
                    kW += 1
                pos_add = torch.randn(C, kH, kW, dtype=dtype, device=device)
                pos_mul = torch.randn(C, kH, kW, dtype=dtype, device=device)
                input.requires_grad = True
                pos_add.requires_grad = True
                pos_mul.requires_grad = True
                output = mutual_information(input, pos_add, pos_mul)
                output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
                output.backward(gradient=output_grad)
                delta = 1.0e-05
                pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype, device=device)
                pred_change = (pos_delta * pos_add.grad).sum().to('cpu').item()
                change = (output_grad * (mutual_information(input, pos_add + pos_delta, pos_mul) - output)).sum().to('cpu').item()
                print(f"For pos_add: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) < 1.0e-04
                pred_change = (pos_delta * pos_mul.grad).sum().to('cpu').item()
                change = (output_grad * (mutual_information(input, pos_add, pos_mul + pos_delta) - output)).sum().to('cpu').item()
                print(f"For pos_mul: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) / abs(change) < 1.0e-04
                input_delta = delta * torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
                pred_change = (input_delta * input.grad).sum().to('cpu').item()
                change = (output_grad * (mutual_information(input + input_delta, pos_add, pos_mul) - output)).sum().to('cpu').item()
                print(f"For input: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) / abs(change) < 1.0e-04


if __name__ == "__main__":
    test_mutual_information_basic()
    #test_mutual_information_deriv()
    if False:
        test_mutual_information_rand_grad()
    test_mutual_information_zeros()
    test_mutual_information_compare()
    test_mutual_information_rand_compare()

Added by this commit (a new test_mutual_information_deriv(), adapted from the body of test_mutual_information_basic() and driving mutual_information_recursion(px, py, boundary); it still carries some text and fixed sizes copied from the basic test):

def test_mutual_information_deriv():
    print("Running test_mutual_information_basic()")
    for _iter in range(100):
        (B, S, T) = (random.randint(1, 10), random.randint(1, 200),
                     random.randint(1, 200))
        random_px = (random.random() < 0.1)
        random_py = (random.random() < 0.1)
        big_px = (random.random() < 0.1)
        big_py = (random.random() < 0.1)
        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}")
        for dtype in [torch.float32, torch.float64]:
            px_grads = []
            py_grads = []
            m_vals = []
            for device in [torch.device('cpu'), torch.device('cuda:0')]:
                print("dtype = ", dtype, ", device = ", device)
                B = 2
                S = 14
                T = 14
                boundary = torch.tensor([0, 0, S, T], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
                if device == torch.device('cpu'):
                    if random_px:
                        px = torch.randn(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    # px and py get exponentiated, and then multiplied together up to
                    # 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
                    # could lead to overflow.
                    if big_px:
                        px += 15.0
                    if random_py:
                        py = torch.randn(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    if big_py:
                        py += 15.0
                else:
                    px = px.to(device).detach()
                    py = py.to(device).detach()
                px.requires_grad = True
                py.requires_grad = True
                #m = mutual_information_recursion(px, py, None)
                m = mutual_information_recursion(px, py, boundary)
                #print("m = ", m, ", size = ", m.shape)
                #print("exp(m) = ", m.exp())
                (m.sum() * 3).backward()
                #print("px_grad = ", px.grad)
                #print("py_grad = ", py.grad)
                px_grads.append(px.grad.to('cpu'))
                py_grads.append(py.grad.to('cpu'))
                m_vals.append(m.to('cpu'))
            if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-05, rtol=1.0e-04):
                print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
                assert 0
            if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-05, rtol=1.0e-04):
                print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
                assert 0
            if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-05, rtol=1.0e-04):
                print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
                assert 0

The main block now runs the two surviving tests:

if __name__ == "__main__":
    test_mutual_information_basic()
    test_mutual_information_deriv()
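The pred_change / observed_change checks in the removed test_mutual_information_deriv() and the pred_change / change comparisons in the removed test_mutual_information_rand_grad() are instances of the same first-order gradient check: perturb an input by a small random delta and compare grad dot delta against the actual change in the output. A self-contained sketch of that pattern, using a stand-in scalar function rather than anything from this package (all names below are illustrative):

    import torch

    def finite_difference_check(f, x, eps=1.0e-04, rtol=5.0e-02, atol=3.0e-05):
        # First-order check: (df/dx . delta) should match f(x + delta) - f(x) for small delta.
        x = x.detach().clone().requires_grad_(True)
        y = f(x)
        y.backward()
        delta = eps * torch.randn_like(x)
        pred_change = (x.grad * delta).sum()
        observed_change = f(x.detach() + delta) - y.detach()
        assert torch.allclose(pred_change, observed_change, rtol=rtol, atol=atol), \
            (pred_change, observed_change)

    # Stand-in smooth function; float64 keeps the roundoff comfortably below the tolerances.
    finite_difference_check(lambda t: (2.0 * t + t.exp()).sum(),
                            torch.randn(3, 4, dtype=torch.float64))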