Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
b3c1340d
Commit
b3c1340d
authored
Jul 02, 2021
by
Daniel Povey
Browse files
Test backward; now seems to work.
parent
86e3a617
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
135 additions
and
33 deletions
+135
-33
torch_integrated_conv/integrated_conv_cpu.cpp
torch_integrated_conv/integrated_conv_cpu.cpp
+4
-4
torch_integrated_conv/integrated_conv_cuda_kernel.cu
torch_integrated_conv/integrated_conv_cuda_kernel.cu
+31
-21
torch_integrated_conv/integrated_conv_test.py
torch_integrated_conv/integrated_conv_test.py
+100
-8
No files found.
torch_integrated_conv/integrated_conv_cpu.cpp
View file @
b3c1340d
...
@@ -58,7 +58,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
...
@@ -58,7 +58,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
static_cast
<
unsigned
int
>
(
src_w
)
<
static_cast
<
unsigned
int
>
(
W
))
static_cast
<
unsigned
int
>
(
src_w
)
<
static_cast
<
unsigned
int
>
(
W
))
src
=
src_input_a
[
src_h
][
src_w
];
src
=
src_input_a
[
src_h
][
src_w
];
scalar_t
relu
=
src
+
dest
+
this_pos_add_a
[
kh
][
kw
];
scalar_t
relu
=
src
+
dest
+
this_pos_add_a
[
kh
][
kw
];
if
(
relu
>
0.0
)
if
(
relu
>
=
0.0
)
sum
+=
relu
*
this_pos_mul_a
[
kh
][
kw
];
sum
+=
relu
*
this_pos_mul_a
[
kh
][
kw
];
}
}
}
}
...
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
...
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
for
(
int
h
=
0
;
h
<
H
;
h
++
)
{
for
(
int
h
=
0
;
h
<
H
;
h
++
)
{
for
(
int
w
=
0
;
w
<
W
;
w
++
)
{
for
(
int
w
=
0
;
w
<
W
;
w
++
)
{
scalar_t
dest
=
input_a
[
n
][
c
+
C
][
h
][
w
],
scalar_t
dest
=
input_a
[
n
][
c
+
C
][
h
][
w
],
dest_grad
=
0.0
,
// to be multiplied by this_output
_grad
later..
dest_grad
=
0.0
,
// to be multiplied by this_
grad_
output later..
this_grad_output
=
grad_output_a
[
n
][
c
][
h
][
w
];
this_grad_output
=
grad_output_a
[
n
][
c
][
h
][
w
];
for
(
int
kh
=
0
;
kh
<
kH
;
kh
++
)
{
for
(
int
kh
=
0
;
kh
<
kH
;
kh
++
)
{
int
src_h
=
h
+
kh
-
kH
/
2
;
int
src_h
=
h
+
kh
-
kH
/
2
;
...
@@ -140,7 +140,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
...
@@ -140,7 +140,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
scalar_t
relu
=
src
+
dest
+
pos_add_a
[
c
][
kh
][
kw
];
scalar_t
relu
=
src
+
dest
+
pos_add_a
[
c
][
kh
][
kw
];
if
(
relu
>=
0.0
)
{
if
(
relu
>=
0.0
)
{
scalar_t
pos_mul_val
=
pos_mul_a
[
c
][
kh
][
kw
];
scalar_t
pos_mul_val
=
pos_mul_a
[
c
][
kh
][
kw
];
dest_grad
+=
pos_mul_val
;
// will later multiply by this_output
_grad
dest_grad
+=
pos_mul_val
;
// will later multiply by this_
grad_
output
grad_pos_add_a
[
c
][
kh
][
kw
]
+=
this_grad_output
*
pos_mul_val
;
grad_pos_add_a
[
c
][
kh
][
kw
]
+=
this_grad_output
*
pos_mul_val
;
grad_pos_mul_a
[
c
][
kh
][
kw
]
+=
this_grad_output
*
relu
;
grad_pos_mul_a
[
c
][
kh
][
kw
]
+=
this_grad_output
*
relu
;
if
(
static_cast
<
unsigned
int
>
(
src_h
)
<
static_cast
<
unsigned
int
>
(
H
)
&&
if
(
static_cast
<
unsigned
int
>
(
src_h
)
<
static_cast
<
unsigned
int
>
(
H
)
&&
...
@@ -149,7 +149,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
...
@@ -149,7 +149,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
}
}
}
}
}
}
grad_input_a
[
n
][
c
+
C
][
h
][
w
]
+
=
dest_grad
*
this_grad_output
;
grad_input_a
[
n
][
c
+
C
][
h
][
w
]
=
dest_grad
*
this_grad_output
;
}
}
}
}
}
}
...
...
torch_integrated_conv/integrated_conv_cuda_kernel.cu
View file @
b3c1340d
...
@@ -348,7 +348,7 @@ void integrated_conv_kernel_backward(
...
@@ -348,7 +348,7 @@ void integrated_conv_kernel_backward(
// where the 'h' and 'w' indexes are into the zero-padded input
// where the 'h' and 'w' indexes are into the zero-padded input
// image.
// image.
*
dest_img_buf
=
src_img_buf
+
ppatch_size
,
// version of input image that relates to destinatioon position
*
dest_img_buf
=
src_img_buf
+
ppatch_size
,
// version of input image that relates to destinatioon position
*
grad_output_buf
=
src
_img_buf
+
ppatch_size
,
// output gradient for padded patch, indexed [h*ppatchW + w]
*
grad_output_buf
=
dest
_img_buf
+
ppatch_size
,
// output gradient for padded patch, indexed [h*ppatchW + w]
*
grad_pos_add_buf
=
grad_output_buf
+
ppatch_size
,
// total grad for pos_add for this thread block, indexed [kh*kW + kw]
*
grad_pos_add_buf
=
grad_output_buf
+
ppatch_size
,
// total grad for pos_add for this thread block, indexed [kh*kW + kw]
*
grad_pos_mul_buf
=
grad_pos_add_buf
+
(
kH
*
kW
),
// total grad for pos_mul for this thread block, indexed [kh*kW + kw]
*
grad_pos_mul_buf
=
grad_pos_add_buf
+
(
kH
*
kW
),
// total grad for pos_mul for this thread block, indexed [kh*kW + kw]
*
reduce_buf
=
grad_pos_mul_buf
+
(
kH
*
kW
);
// buffer for reduction over threads, size == blockDim.x
*
reduce_buf
=
grad_pos_mul_buf
+
(
kH
*
kW
);
// buffer for reduction over threads, size == blockDim.x
...
@@ -360,13 +360,17 @@ void integrated_conv_kernel_backward(
...
@@ -360,13 +360,17 @@ void integrated_conv_kernel_backward(
// Load parts of the kernel parameters pos_add and pos_mul into shared memory,
// Load parts of the kernel parameters pos_add and pos_mul into shared memory,
// in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
// in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
// We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.
// We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.
if
(
threadIdx
.
x
<
kH
*
kW
)
{
int
i
=
threadIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
%
(
blockDim
.
x
/
2
);
i
<
kH
*
kW
;
i
+=
(
blockDim
.
x
/
2
))
{
int
kh
=
i
/
kW
,
kw
=
i
%
kW
;
int
kh
=
i
/
kW
,
kw
=
i
%
kW
;
pos_add_buf
[
i
]
=
pos_add
[
c
][
kh
][
kw
];
if
(
threadIdx
.
x
<
blockDim
.
x
/
2
)
{
// First half of threads take care of pos_add..
pos_mul_buf
[
i
]
=
pos_mul
[
c
][
kh
][
kw
];
pos_add_buf
[
i
]
=
pos_add
[
c
][
kh
][
kw
];
grad_pos_add_buf
[
i
]
=
0.0
;
grad_pos_add_buf
[
i
]
=
0.0
;
grad_pos_mul_buf
[
i
]
=
0.0
;
}
else
{
// Second half take care of pos_mul... there is no warp divergence
// because we make sure blockDim.x is a multiple of 64.
pos_mul_buf
[
i
]
=
pos_mul
[
c
][
kh
][
kw
];
grad_pos_mul_buf
[
i
]
=
0.0
;
}
}
}
// n is the index within the batch of images. Loop to make sure we cover all
// n is the index within the batch of images. Loop to make sure we cover all
...
@@ -391,7 +395,8 @@ void integrated_conv_kernel_backward(
...
@@ -391,7 +395,8 @@ void integrated_conv_kernel_backward(
// Load the 'src' and 'dest' versions of the padded patch into
// Load the 'src' and 'dest' versions of the padded patch into
// shared-memory buffers, and also the output gradient.
// shared-memory buffers, and also the output gradient.
for
(
int
i
=
threadIdx
.
x
%
(
blockDim
.
x
/
2
);
i
<
ppatch_size
;
i
+=
(
blockDim
.
x
/
2
))
{
for
(
int
i
=
threadIdx
.
x
%
(
blockDim
.
x
/
2
);
i
<
ppatch_size
;
i
+=
(
blockDim
.
x
/
2
))
{
int
h_in_ppatch
=
i
/
ppatchW
,
int
h_in_ppatch
=
i
/
ppatchW
,
w_in_ppatch
=
i
%
ppatchW
;
w_in_ppatch
=
i
%
ppatchW
;
int
h
=
patch_h_offset
+
h_in_ppatch
-
(
kH
/
2
),
// kH / 2 is offset due to padding
int
h
=
patch_h_offset
+
h_in_ppatch
-
(
kH
/
2
),
// kH / 2 is offset due to padding
...
@@ -401,7 +406,7 @@ void integrated_conv_kernel_backward(
...
@@ -401,7 +406,7 @@ void integrated_conv_kernel_backward(
// load `input`
// load `input`
scalar_t
src_val
=
scalar_t
(
0
),
scalar_t
src_val
=
scalar_t
(
0
),
dest_val
=
scalar_t
(
0
);
dest_val
=
scalar_t
(
0
);
if
((
unsigned
int
)
h
<
(
unsigned
int
)
H
&&
// h >= 0 && h < H
.
if
((
unsigned
int
)
h
<
(
unsigned
int
)
H
&&
// h >= 0 && h < H
(
unsigned
int
)
w
<
(
unsigned
int
)
W
)
{
// w >= 0 && w < W
(
unsigned
int
)
w
<
(
unsigned
int
)
W
)
{
// w >= 0 && w < W
int
C
=
grad_output
.
size
(
1
);
int
C
=
grad_output
.
size
(
1
);
src_val
=
input
[
n
][
c
][
h
][
w
];
src_val
=
input
[
n
][
c
][
h
][
w
];
...
@@ -429,7 +434,7 @@ void integrated_conv_kernel_backward(
...
@@ -429,7 +434,7 @@ void integrated_conv_kernel_backward(
grad_input_dest_sum
=
0.0
;
// grad for channel c + C, for our pixel
grad_input_dest_sum
=
0.0
;
// grad for channel c + C, for our pixel
// of `input` (contribution of this thread)
// of `input` (contribution of this thread)
if
(
pos_in_patch
<
patch_size
)
{
if
(
pos_in_patch
<
patch_size
)
{
// This block computes `grad_input_sum`
.
// This block computes `grad_input_s
rc_sum` and `grad_input_dest_s
um`
// The num-threads for the backward kernel may not be an exact multiple
// The num-threads for the backward kernel may not be an exact multiple
// of patch_size, wo we need the if-guard.
// of patch_size, wo we need the if-guard.
...
@@ -456,7 +461,6 @@ void integrated_conv_kernel_backward(
...
@@ -456,7 +461,6 @@ void integrated_conv_kernel_backward(
// This is actually more like cross-correlation, as we don't have a
// This is actually more like cross-correlation, as we don't have a
// negative sign on the h and w indexes in the kernel.
// negative sign on the h and w indexes in the kernel.
int
src_h_in_ppatch
=
h_in_patch
+
h_in_kernel
,
int
src_h_in_ppatch
=
h_in_patch
+
h_in_kernel
,
src_w_in_ppatch
=
w_in_patch
+
w_in_kernel
;
src_w_in_ppatch
=
w_in_patch
+
w_in_kernel
;
int
src_pos_in_ppatch
=
src_h_in_ppatch
*
ppatchW
+
src_w_in_ppatch
;
int
src_pos_in_ppatch
=
src_h_in_ppatch
*
ppatchW
+
src_w_in_ppatch
;
...
@@ -469,9 +473,11 @@ void integrated_conv_kernel_backward(
...
@@ -469,9 +473,11 @@ void integrated_conv_kernel_backward(
scalar_t
this_grad_output
=
grad_output_buf
[
pos_in_ppatch
];
scalar_t
this_grad_output
=
grad_output_buf
[
pos_in_ppatch
];
grad_input_dest_sum
+=
this_grad_output
*
pos_mul_val
;
grad_input_dest_sum
+=
this_grad_output
*
pos_mul_val
;
}
}
// To compute a contribution to "this_input_src_grad", we need to consider the
// To compute a contribution to "this_input_src_grad", we need to
// contribution to the destination pixel that it would have contributed to
// consider the contribution to the destination pixel that it would
// with this same offset.
// have contributed to with this same offset.
// We have to flip the offsets: instead of "+ h_in_kernel",
// we use (kH - 1) - h_in_kernel,.
int
dest_h_in_ppatch
=
h_in_patch
+
(
kH
-
1
)
-
h_in_kernel
,
int
dest_h_in_ppatch
=
h_in_patch
+
(
kH
-
1
)
-
h_in_kernel
,
dest_w_in_ppatch
=
w_in_patch
+
(
kW
-
1
)
-
w_in_kernel
,
dest_w_in_ppatch
=
w_in_patch
+
(
kW
-
1
)
-
w_in_kernel
,
dest_pos_in_ppatch
=
dest_h_in_ppatch
*
ppatchW
+
dest_w_in_ppatch
;
dest_pos_in_ppatch
=
dest_h_in_ppatch
*
ppatchW
+
dest_w_in_ppatch
;
...
@@ -485,6 +491,7 @@ void integrated_conv_kernel_backward(
...
@@ -485,6 +491,7 @@ void integrated_conv_kernel_backward(
}
}
// Aggregate `grad_input_src_sum` over threads, if needed; and write the
// Aggregate `grad_input_src_sum` over threads, if needed; and write the
// result to `grad_input`.
// result to `grad_input`.
// h and w are un-padded indexes into the entire image.
int
h
=
patch_h_offset
+
pos_in_patch
/
patchW
,
int
h
=
patch_h_offset
+
pos_in_patch
/
patchW
,
w
=
patch_w_offset
+
pos_in_patch
%
patchW
;
w
=
patch_w_offset
+
pos_in_patch
%
patchW
;
...
@@ -514,7 +521,8 @@ void integrated_conv_kernel_backward(
...
@@ -514,7 +521,8 @@ void integrated_conv_kernel_backward(
kw
=
pos_in_kernel
%
kW
;
kw
=
pos_in_kernel
%
kW
;
// This group of (threads_per_kernel_pos) threads is responsible
// This group of (threads_per_kernel_pos) threads is responsible
// for position (kh, kw) in the kernel; we iterate over the patch.
// for position (kh, kw) in the kernel; we iterate over the patch
// (an un-padded patch of output).
scalar_t
pos_add_val
=
pos_add_buf
[
pos_in_kernel
],
scalar_t
pos_add_val
=
pos_add_buf
[
pos_in_kernel
],
pos_mul_val
=
pos_mul_buf
[
pos_in_kernel
];
pos_mul_val
=
pos_mul_buf
[
pos_in_kernel
];
...
@@ -524,15 +532,15 @@ void integrated_conv_kernel_backward(
...
@@ -524,15 +532,15 @@ void integrated_conv_kernel_backward(
// and pos_mul; we let `pos_in_patch` correspond to the *output*
// and pos_mul; we let `pos_in_patch` correspond to the *output*
// position, and work out the input position based on gthe kernel position.
// position, and work out the input position based on gthe kernel position.
int
h_in_patch
=
pos_in_patch
/
patch
H
,
int
h_in_patch
=
pos_in_patch
/
patch
W
,
w_in_patch
=
pos_in_patch
/
patchW
;
w_in_patch
=
pos_in_patch
%
patchW
;
// pos_in_ppatch is the position in the padded patch corresponding to
// pos_in_ppatch is the position in the padded patch corresponding to
// `pos_in_patch`.
// `pos_in_patch`.
int
pos_in_ppatch
=
(
h_in_patch
+
kH
/
2
)
*
ppatchW
+
(
w_in_patch
+
kW
/
2
);
int
pos_in_ppatch
=
(
h_in_patch
+
kH
/
2
)
*
ppatchW
+
(
w_in_patch
+
kW
/
2
);
scalar_t
dest_val
=
dest_img_buf
[
pos_in_ppatch
];
scalar_t
dest_val
=
dest_img_buf
[
pos_in_ppatch
];
int
offset
_pos_in_ppatch
=
(
h_in_patch
+
kh
)
*
ppatchW
+
(
w_in_patch
+
kw
);
int
src
_pos_in_ppatch
=
(
h_in_patch
+
kh
)
*
ppatchW
+
(
w_in_patch
+
kw
);
scalar_t
src_val
=
src_img_buf
[
offset
_pos_in_ppatch
];
scalar_t
src_val
=
src_img_buf
[
src
_pos_in_ppatch
];
scalar_t
relu
=
dest_val
+
src_val
+
pos_add_val
;
scalar_t
relu
=
dest_val
+
src_val
+
pos_add_val
;
if
(
relu
>=
0.0
)
{
if
(
relu
>=
0.0
)
{
...
@@ -546,13 +554,15 @@ void integrated_conv_kernel_backward(
...
@@ -546,13 +554,15 @@ void integrated_conv_kernel_backward(
this_grad_pos_mul
=
tiled_warp_reduce_sum
(
this_grad_pos_mul
=
tiled_warp_reduce_sum
(
threads_per_kernel_pos
,
reduce_buf
,
this_grad_pos_mul
);
threads_per_kernel_pos
,
reduce_buf
,
this_grad_pos_mul
);
if
(
threadIdx
.
x
%
threads_per_kernel_pos
==
0
)
{
if
(
threadIdx
.
x
%
threads_per_kernel_pos
==
0
)
{
grad_pos_add_buf
[
pos_in_kernel
]
=
this_grad_pos_add
;
grad_pos_add_buf
[
pos_in_kernel
]
+
=
this_grad_pos_add
;
grad_pos_mul_buf
[
pos_in_kernel
]
=
this_grad_pos_mul
;
grad_pos_mul_buf
[
pos_in_kernel
]
+
=
this_grad_pos_mul
;
}
}
}
}
}
}
}
}
__syncthreads
();
// make sure all threads have written to grad_pos_add_buf and
// grad_pos_mul_buf.
int
block
=
blockIdx
.
z
*
gridDim
.
y
+
blockIdx
.
y
;
int
block
=
blockIdx
.
z
*
gridDim
.
y
+
blockIdx
.
y
;
int
kernel_pos
=
threadIdx
.
x
;
int
kernel_pos
=
threadIdx
.
x
;
...
...
torch_integrated_conv/integrated_conv_test.py
View file @
b3c1340d
...
@@ -18,7 +18,7 @@ def test_integrated_conv_zeros():
...
@@ -18,7 +18,7 @@ def test_integrated_conv_zeros():
kH
=
5
kH
=
5
kW
=
5
kW
=
5
pos_add
=
torch
.
zeros
(
C
,
kH
,
kW
,
device
=
device
,
dtype
=
dtype
)
pos_add
=
torch
.
zeros
(
C
,
kH
,
kW
,
device
=
device
,
dtype
=
dtype
)
pos_mul
=
torch
.
zero
s
(
C
,
kH
,
kW
,
device
=
device
,
dtype
=
dtype
)
pos_mul
=
torch
.
one
s
(
C
,
kH
,
kW
,
device
=
device
,
dtype
=
dtype
)
input
.
requires_grad
=
True
input
.
requires_grad
=
True
pos_add
.
requires_grad
=
True
pos_add
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
...
@@ -45,20 +45,28 @@ def test_integrated_conv_compare():
...
@@ -45,20 +45,28 @@ def test_integrated_conv_compare():
print
(
"dtype="
,
dtype
)
print
(
"dtype="
,
dtype
)
input
=
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
input
=
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
device
=
torch
.
device
(
'cuda:0'
)
device
=
torch
.
device
(
'cuda:0'
)
input_cuda
=
input
.
to
(
device
)
input_cuda
=
input
.
to
(
device
)
.
detach
()
kH
=
5
kH
=
5
kW
=
5
kW
=
5
pos_add
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add_cuda
=
pos_add
.
to
(
device
)
pos_add_cuda
=
pos_add
.
to
(
device
).
detach
()
pos_mul_cuda
=
pos_mul
.
to
(
device
)
pos_mul_cuda
=
pos_mul
.
to
(
device
).
detach
()
for
x
in
[
pos_add
,
pos_mul
,
pos_add_cuda
,
pos_mul_cuda
,
input
,
input_cuda
]:
x
.
requires_grad
=
True
output
=
integrated_conv
(
input
,
pos_add
,
pos_mul
)
output
=
integrated_conv
(
input
,
pos_add
,
pos_mul
)
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
print
(
"output = "
,
output
)
print
(
"output = "
,
output
)
print
(
"output_cuda = "
,
output_cuda
)
print
(
"output_cuda = "
,
output_cuda
)
output_grad
=
torch
.
randn
(
*
output
.
shape
,
dtype
=
dtype
)
output
.
backward
(
gradient
=
output_grad
)
output_cuda
.
backward
(
gradient
=
output_grad
.
to
(
device
))
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
abs
=
output
.
abs
().
sum
()
abs
=
output
.
abs
().
sum
()
print
(
"Diff = "
,
diff
,
", abs = "
,
abs
)
print
(
"Diff = "
,
diff
,
", abs = "
,
abs
)
...
@@ -66,6 +74,21 @@ def test_integrated_conv_compare():
...
@@ -66,6 +74,21 @@ def test_integrated_conv_compare():
atol
=
1.0e-05
)
atol
=
1.0e-05
)
for
a
,
b
,
name
in
[
(
pos_add
,
pos_add_cuda
,
'pos_add'
),
(
pos_mul
,
pos_mul_cuda
,
'pos_mul'
),
(
input
,
input_cuda
,
'input'
)
]:
grad
=
a
.
grad
cuda_grad
=
b
.
grad
.
to
(
torch
.
device
(
'cpu'
))
diff_abs
=
(
grad
-
cuda_grad
).
abs
().
sum
().
item
()
sum_abs
=
(
grad
+
cuda_grad
).
abs
().
sum
().
item
()
print
(
f
"Comparing grad of
{
name
}
: diff=
{
diff_abs
}
, sum=
{
sum_abs
}
"
)
if
diff_abs
>
1.0e-05
*
sum_abs
:
print
(
f
"Error: too much difference in grad of
{
name
}
."
)
print
(
"grad = "
,
grad
)
print
(
"cuda_grad = "
,
cuda_grad
)
def
test_integrated_conv_rand_compare
():
def
test_integrated_conv_rand_compare
():
for
_
in
range
(
30
):
for
_
in
range
(
30
):
N
=
random
.
randint
(
1
,
256
)
N
=
random
.
randint
(
1
,
256
)
...
@@ -108,11 +131,80 @@ def test_integrated_conv_rand_compare():
...
@@ -108,11 +131,80 @@ def test_integrated_conv_rand_compare():
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
abs
=
output
.
abs
().
sum
()
sum_
abs
=
output
.
abs
().
sum
()
print
(
"Diff = "
,
diff
,
", abs = "
,
abs
)
print
(
"Diff = "
,
diff
,
", abs = "
,
sum_
abs
)
if
not
torch
.
allclose
(
output
,
output_cuda
.
to
(
torch
.
device
(
'cpu'
)),
if
(
diff
/
sum_abs
).
item
()
>
0.001
:
atol
=
1.0e-05
):
print
(
"output = "
,
output
)
print
(
"output = "
,
output
)
print
(
"output_cuda = "
,
output_cuda
)
print
(
"output_cuda = "
,
output_cuda
)
assert
0
,
"outputs differ"
assert
0
,
"outputs differ"
def
test_integrated_conv_rand_grad
():
for
_
in
range
(
30
):
N
=
random
.
randint
(
1
,
256
)
C
=
random
.
randint
(
1
,
64
)
H
=
random
.
randint
(
1
,
128
)
W
=
random
.
randint
(
1
,
128
)
while
N
*
C
*
H
*
W
>
65535
:
if
N
>=
C
and
N
>=
H
and
N
>=
W
:
N
=
N
//
2
elif
C
>=
H
and
C
>=
W
:
C
=
C
//
2
elif
H
>=
W
:
H
=
H
//
2
else
:
W
=
W
//
2
for
device
in
[
torch
.
device
(
'cpu'
),
torch
.
device
(
'cuda:0'
)
]:
if
device
==
torch
.
device
(
'cuda:0'
)
and
not
torch
.
cuda
.
is_available
():
print
(
"Warning: torch not available, not testing this part."
)
continue
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
print
(
"dtype="
,
dtype
,
", device="
,
device
)
input
=
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
kH
=
random
.
randint
(
1
,
10
)
kW
=
random
.
randint
(
1
,
10
)
if
kH
%
2
==
0
:
kH
+=
1
if
kW
%
2
==
0
:
kW
+=
1
pos_add
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
,
device
=
device
)
pos_mul
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
,
device
=
device
)
input
.
requires_grad
=
True
pos_add
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
output
=
integrated_conv
(
input
,
pos_add
,
pos_mul
)
output_grad
=
torch
.
randn
(
N
,
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
output
.
backward
(
gradient
=
output_grad
)
delta
=
1.0e-05
pos_delta
=
delta
*
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
,
device
=
device
)
pred_change
=
(
pos_delta
*
pos_add
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrated_conv
(
input
,
pos_add
+
pos_delta
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For pos_add: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) < 1.0e-04
pred_change
=
(
pos_delta
*
pos_mul
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrated_conv
(
input
,
pos_add
,
pos_mul
+
pos_delta
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For pos_mul: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) / abs(change) < 1.0e-04
input_delta
=
delta
*
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
pred_change
=
(
input_delta
*
input
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrated_conv
(
input
+
input_delta
,
pos_add
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For input: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) / abs(change) < 1.0e-04
if
__name__
==
"__main__"
:
test_integrated_conv_rand_grad
()
test_integrated_conv_zeros
()
test_integrated_conv_compare
()
test_integrated_conv_rand_compare
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment