Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
5fc62fa6
Commit
5fc62fa6
authored
Jun 30, 2021
by
Daniel Povey
Browse files
Get forward tests to work
parent
10381dab
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
13 deletions
+26
-13
torch_integrated_conv/integrated_conv_cpu.cpp
torch_integrated_conv/integrated_conv_cpu.cpp
+1
-1
torch_integrated_conv/integrated_conv_cuda_kernel.cu
torch_integrated_conv/integrated_conv_cuda_kernel.cu
+3
-2
torch_integrated_conv/integrated_conv_test.py
torch_integrated_conv/integrated_conv_test.py
+22
-10
No files found.
torch_integrated_conv/integrated_conv_cpu.cpp
View file @
5fc62fa6
...
...
@@ -37,7 +37,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
auto
input_a
=
input
.
accessor
<
scalar_t
,
4
>
(),
output_a
=
output
.
accessor
<
scalar_t
,
4
>
();
auto
pos_add_a
=
pos_add
.
accessor
<
scalar_t
,
3
>
(),
pos_mul_a
=
pos_
add
.
accessor
<
scalar_t
,
3
>
();
pos_mul_a
=
pos_
mul
.
accessor
<
scalar_t
,
3
>
();
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
c
=
0
;
c
<
C
;
c
++
)
{
...
...
torch_integrated_conv/integrated_conv_cuda_kernel.cu
View file @
5fc62fa6
...
...
@@ -294,9 +294,10 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
<<
"; patchH,patchW="
<<
patchH
<<
","
<<
patchW
<<
", num_blocks_patch="
<<
num_blocks_patch
<<
", num_blocks_batch="
<<
num_blocks_batch
<<
std
::
endl
<<
num_blocks_batch
<<
", threads_per_opixel="
<<
threads_per_opixel
<<
", threads_per_block="
<<
threads_per_block
;
<<
", threads_per_block="
<<
threads_per_block
<<
std
::
endl
;
dim3
gridDim
(
C
,
num_blocks_patch
,
num_blocks_batch
);
// blockDim is scalar, just threads_per_block.
...
...
torch_integrated_conv/integrated_conv_test.py
View file @
5fc62fa6
...
...
@@ -35,14 +35,15 @@ def test_integrated_conv_compare():
return
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
print
(
"dtype="
,
dtype
)
input
=
torch
.
ones
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
input
=
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
device
=
torch
.
device
(
'cuda:0'
)
input_cuda
=
input
.
to
(
device
)
kH
=
5
kW
=
5
pos_add
=
torch
.
ones
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
ones
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add_cuda
=
pos_add
.
to
(
device
)
pos_mul_cuda
=
pos_mul
.
to
(
device
)
...
...
@@ -50,7 +51,11 @@ def test_integrated_conv_compare():
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
print
(
"output = "
,
output
)
print
(
"output_cuda = "
,
output_cuda
)
assert
torch
.
allclose
(
output
,
output_cuda
.
to
(
torch
.
device
(
'cpu'
)))
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
abs
=
output
.
abs
().
sum
()
print
(
"Diff = "
,
diff
,
", abs = "
,
abs
)
assert
torch
.
allclose
(
output
,
output_cuda
.
to
(
torch
.
device
(
'cpu'
)),
atol
=
1.0e-05
)
def
test_integrated_conv_rand_compare
():
...
...
@@ -76,7 +81,7 @@ def test_integrated_conv_rand_compare():
return
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
print
(
"dtype="
,
dtype
)
input
=
torch
.
ones
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
input
=
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
)
device
=
torch
.
device
(
'cuda:0'
)
input_cuda
=
input
.
to
(
device
)
...
...
@@ -86,13 +91,20 @@ def test_integrated_conv_rand_compare():
kH
+=
1
if
kW
%
2
==
0
:
kW
+=
1
pos_add
=
torch
.
ones
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
ones
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_mul
=
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
)
pos_add_cuda
=
pos_add
.
to
(
device
)
pos_mul_cuda
=
pos_mul
.
to
(
device
)
output
=
integrated_conv
(
input
,
pos_add
,
pos_mul
)
output_cuda
=
integrated_conv
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
abs
=
output
.
abs
().
sum
()
print
(
"Diff = "
,
diff
,
", abs = "
,
abs
)
if
not
torch
.
allclose
(
output
,
output_cuda
.
to
(
torch
.
device
(
'cpu'
)),
atol
=
1.0e-05
):
print
(
"output = "
,
output
)
print
(
"output_cuda = "
,
output_cuda
)
assert
torch
.
allclose
(
output
,
output
_cuda
.
to
(
torch
.
device
(
'cpu'
)))
assert
0
,
"
output
s differ"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment